diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -23,6 +23,17 @@ #include "llvm/ADT/Twine.h" namespace mlir { +//===--------------------------------------------------------------------===// +// Dialect Version Interface. +//===--------------------------------------------------------------------===// + +/// This class is used to represent the version of a dialect, for the purpose +/// of polymorphic destruction. +class DialectVersion { +public: + virtual ~DialectVersion() = default; +}; + //===----------------------------------------------------------------------===// // DialectBytecodeReader //===----------------------------------------------------------------------===// @@ -39,6 +50,10 @@ /// Emit an error to the reader. virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; + /// Retrieve the dialect version by name if available. + virtual FailureOr + getDialectVersion(StringRef dialectName) const = 0; + /// Read out a list of elements, invoking the provided callback for each /// element. The callback function may be in any of the following forms: /// * LogicalResult(T &) @@ -261,17 +276,6 @@ virtual int64_t getBytecodeVersion() const = 0; }; -//===--------------------------------------------------------------------===// -// Dialect Version Interface. -//===--------------------------------------------------------------------===// - -/// This class is used to represent the version of a dialect, for the purpose -/// of polymorphic destruction. 
-class DialectVersion { -public: - virtual ~DialectVersion() = default; -}; - //===----------------------------------------------------------------------===// // BytecodeDialectInterface //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2526,6 +2526,10 @@ // Whether this op has a folder. bit hasFolder = 0; + // Whether to let ops implement their custom `readProperties` and + // `writeProperties` methods to emit bytecode. + bit useCustomPropertiesEncoding = 0; + // Op traits. // Note: The list of traits will be uniqued by ODS. list traits = props; diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -352,6 +352,10 @@ bool hasFolder() const; + /// Whether the op provides its own `readProperties`/`writeProperties` + /// methods for bytecode emission instead of having ODS generate them. + bool useCustomPropertiesEncoding() const; + private: /// Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -311,7 +311,8 @@ // Parse in the remaining bytes of the value. llvm::support::ulittle64_t resultLE(result); - if (failed(parseBytes(numBytes, reinterpret_cast(&resultLE) + 1))) + if (failed( + parseBytes(numBytes, reinterpret_cast(&resultLE) + 1))) return failure(); // Shift out the low-order bits that were used to mark how the value was @@ -505,10 +506,11 @@ /// Parse a single dialect group encoded in the byte stream.
static LogicalResult parseDialectGrouping( - EncodingReader &reader, MutableArrayRef dialects, + EncodingReader &reader, + MutableArrayRef> dialects, function_ref entryCallback) { // Parse the dialect and the number of entries in the group. - BytecodeDialect *dialect; + std::unique_ptr *dialect; if (failed(parseEntry(reader, dialects, dialect, "dialect"))) return failure(); uint64_t numEntries; @@ -516,7 +518,7 @@ return failure(); for (uint64_t i = 0; i < numEntries; ++i) - if (failed(entryCallback(dialect))) + if (failed(entryCallback(dialect->get()))) return failure(); return success(); } @@ -532,7 +534,7 @@ /// Initialize the resource section reader with the given section data. LogicalResult initialize(Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, + MutableArrayRef> dialects, StringSectionReader &stringReader, ArrayRef sectionData, ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef); @@ -682,7 +684,7 @@ LogicalResult ResourceSectionReader::initialize( Location fileLoc, const ParserConfig &config, - MutableArrayRef dialects, + MutableArrayRef> dialects, StringSectionReader &stringReader, ArrayRef sectionData, ArrayRef offsetSectionData, DialectReader &dialectReader, const std::shared_ptr &bufferOwnerRef) { @@ -731,19 +733,19 @@ // Read the dialect resources from the bytecode. 
MLIRContext *ctx = fileLoc->getContext(); while (!offsetReader.empty()) { - BytecodeDialect *dialect; + std::unique_ptr *dialect; if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || - failed(dialect->load(dialectReader, ctx))) + failed((*dialect)->load(dialectReader, ctx))) return failure(); - Dialect *loadedDialect = dialect->getLoadedDialect(); + Dialect *loadedDialect = (*dialect)->getLoadedDialect(); if (!loadedDialect) { return resourceReader.emitError() - << "dialect '" << dialect->name << "' is unknown"; + << "dialect '" << (*dialect)->name << "' is unknown"; } const auto *handler = dyn_cast(loadedDialect); if (!handler) { return resourceReader.emitError() - << "unexpected resources for dialect '" << dialect->name << "'"; + << "unexpected resources for dialect '" << (*dialect)->name << "'"; } // Ensure that each resource is declared before being processed. @@ -753,7 +755,7 @@ if (failed(handle)) { return resourceReader.emitError() << "unknown 'resource' key '" << key << "' for dialect '" - << dialect->name << "'"; + << (*dialect)->name << "'"; } dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle); dialectResources.push_back(*handle); @@ -796,14 +798,17 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, + const llvm::StringMap &loadedDialectsMap, + Location fileLoc) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + loadedDialectsMap(loadedDialectsMap), fileLoc(fileLoc) {} /// Initialize the attribute and type information within the reader. - LogicalResult initialize(MutableArrayRef dialects, - ArrayRef sectionData, - ArrayRef offsetSectionData); + LogicalResult + initialize(MutableArrayRef> dialects, + ArrayRef sectionData, + ArrayRef offsetSectionData); /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. 
@@ -877,6 +882,10 @@ /// parsing custom encoded attribute/type entries. ResourceSectionReader &resourceReader; + /// The map of the loaded dialects used to retrieve dialect information, such + /// as the dialect version. + const llvm::StringMap &loadedDialectsMap; + /// The set of attribute and type entries. SmallVector attributes; SmallVector types; @@ -889,17 +898,29 @@ public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader) + ResourceSectionReader &resourceReader, + const llvm::StringMap &loadedDialectsMap, + EncodingReader &reader) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader) {} + resourceReader(resourceReader), loadedDialectsMap(loadedDialectsMap), + reader(reader) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); } + FailureOr + getDialectVersion(StringRef dialectName) const override { + auto dialectEntry = loadedDialectsMap.find(dialectName); + if (dialectEntry == loadedDialectsMap.end() || + dialectEntry->getValue()->loadedVersion == nullptr) + return failure(); + return dialectEntry->getValue()->loadedVersion.get(); + } + DialectReader withEncodingReader(EncodingReader &encReader) { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader); + loadedDialectsMap, encReader); } Location getLoc() const { return reader.getLoc(); } @@ -1002,6 +1023,7 @@ AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; ResourceSectionReader &resourceReader; + const llvm::StringMap &loadedDialectsMap; EncodingReader &reader; }; @@ -1067,8 +1089,7 @@ StringRef(rawProperties.begin(), rawProperties.size()), fileLoc); DialectReader propReader = dialectReader.withEncodingReader(reader); - auto *iface = opName->getInterface(); - if (iface) + if (auto *iface = opName->getInterface()) return iface->readProperties(propReader, opState); if 
(opName->isRegistered()) return propReader.emitError( @@ -1087,10 +1108,9 @@ }; } // namespace -LogicalResult -AttrTypeReader::initialize(MutableArrayRef dialects, - ArrayRef sectionData, - ArrayRef offsetSectionData) { +LogicalResult AttrTypeReader::initialize( + MutableArrayRef> dialects, + ArrayRef sectionData, ArrayRef offsetSectionData) { EncodingReader offsetReader(offsetSectionData, fileLoc); // Parse the number of attribute and type entries. @@ -1207,7 +1227,8 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, + loadedDialectsMap, reader); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. @@ -1252,7 +1273,8 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, loadedDialectsMap, + fileLoc), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1518,7 +1540,8 @@ StringRef producer; /// The table of IR units referenced within the bytecode file. - SmallVector dialects; + SmallVector> dialects; + llvm::StringMap loadedDialectsMap; SmallVector opNames; /// The reader used to process resources within the bytecode. @@ -1709,13 +1732,15 @@ // Parse each of the dialects. for (uint64_t i = 0; i < numDialects; ++i) { + dialects[i] = std::make_unique(); /// Before version kDialectVersioning, there wasn't any versioning available /// for dialects, and the entryIdx represent the string itself. 
if (version < bytecode::kDialectVersioning) { - if (failed(stringReader.parseString(sectionReader, dialects[i].name))) + if (failed(stringReader.parseString(sectionReader, dialects[i]->name))) return failure(); continue; } + // Parse ID representing dialect and version. uint64_t dialectNameIdx; bool versionAvailable; @@ -1723,25 +1748,26 @@ versionAvailable))) return failure(); if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, - dialects[i].name))) + dialects[i]->name))) return failure(); if (versionAvailable) { bytecode::Section::ID sectionID; - if (failed( - sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) + if (failed(sectionReader.parseSection(sectionID, + dialects[i]->versionBuffer))) return failure(); if (sectionID != bytecode::Section::kDialectVersions) { emitError(fileLoc, "expected dialect version section"); return failure(); } } + loadedDialectsMap[dialects[i]->name] = dialects[i].get(); } // Parse the operation names, which are grouped by dialect. auto parseOpName = [&](BytecodeDialect *dialect) { StringRef opName; std::optional wasRegistered; - // Prior to version kNativePropertiesEncoding, the information about wheter + // Prior to version kNativePropertiesEncoding, the information about whether // an op was registered or not wasn't encoded. if (version < bytecode::kNativePropertiesEncoding) { if (failed(stringReader.parseString(sectionReader, opName))) @@ -1782,7 +1808,7 @@ if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + loadedDialectsMap, reader); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1825,7 +1851,7 @@ // Initialize the resource reader with the resource sections. 
DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + loadedDialectsMap, reader); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2026,14 +2052,14 @@ "parsed use-list orders were invalid and could not be applied"); // Resolve dialect version. - for (const BytecodeDialect &byteCodeDialect : dialects) { + for (const std::unique_ptr &byteCodeDialect : dialects) { // Parsing is complete, give an opportunity to each dialect to visit the // IR and perform upgrades. - if (!byteCodeDialect.loadedVersion) + if (!byteCodeDialect->loadedVersion) continue; - if (byteCodeDialect.interface && - failed(byteCodeDialect.interface->upgradeFromVersion( - *moduleOp, *byteCodeDialect.loadedVersion))) + if (byteCodeDialect->interface && + failed(byteCodeDialect->interface->upgradeFromVersion( + *moduleOp, *byteCodeDialect->loadedVersion))) return failure(); } @@ -2144,7 +2170,6 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader, RegionReadState &readState, bool &isIsolatedFromAbove) { - // Parse the name of the operation. std::optional wasRegistered; FailureOr opName = parseOpName(reader, wasRegistered); if (failed(opName)) @@ -2186,7 +2211,7 @@ // interface and control the serialization. 
if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + loadedDialectsMap, reader); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -854,3 +854,7 @@ } bool Operator::hasFolder() const { return def.getValueAsBit("hasFolder"); } + +bool Operator::useCustomPropertiesEncoding() const { + return def.getValueAsBit("useCustomPropertiesEncoding"); +} diff --git a/mlir/test/Bytecode/versioning/versioned-op-with-native-prop-1.12.mlirbc b/mlir/test/Bytecode/versioning/versioned-op-with-native-prop-1.12.mlirbc new file mode 100644 index 0000000000000000000000000000000000000000..0000000000000000000000000000000000000000 GIT binary patch literal 0 Hc$@ : () -> () +} + diff --git a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir --- a/mlir/test/Bytecode/versioning/versioned_bytecode.mlir +++ b/mlir/test/Bytecode/versioning/versioned_bytecode.mlir @@ -4,19 +4,19 @@ // Test roundtrip //===--------------------------------------------------------------------===// -// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \ +// RUN: mlir-opt %S/versioned-op-with-prop-1.12.mlirbc -emit-bytecode \ // RUN: -emit-bytecode-version=0 | mlir-opt -o %t.1 && \ -// RUN: mlir-opt %S/versioned-op-1.12.mlirbc -o %t.2 && \ +// RUN: mlir-opt %S/versioned-op-with-prop-1.12.mlirbc -o %t.2 && \ // RUN: diff %t.1 %t.2 //===--------------------------------------------------------------------===// // Test invalid versions //===--------------------------------------------------------------------===// -// RUN: not mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \ +// RUN: not mlir-opt %S/versioned-op-with-prop-1.12.mlirbc -emit-bytecode \ // RUN: -emit-bytecode-version=-1 2>&1 | FileCheck %s 
--check-prefix=ERR_VERSION_NEGATIVE // ERR_VERSION_NEGATIVE: unsupported version requested -1, must be in range [{{[0-9]+}}, {{[0-9]+}}] -// RUN: not mlir-opt %S/versioned-op-1.12.mlirbc -emit-bytecode \ +// RUN: not mlir-opt %S/versioned-op-with-prop-1.12.mlirbc -emit-bytecode \ // RUN: -emit-bytecode-version=999 2>&1 | FileCheck %s --check-prefix=ERR_VERSION_FUTURE // ERR_VERSION_FUTURE: unsupported version requested 999, must be in range [{{[0-9]+}}, {{[0-9]+}}] diff --git a/mlir/test/Bytecode/versioning/versioned_op.mlir b/mlir/test/Bytecode/versioning/versioned_op.mlir --- a/mlir/test/Bytecode/versioning/versioned_op.mlir +++ b/mlir/test/Bytecode/versioning/versioned_op.mlir @@ -1,5 +1,7 @@ // This file contains test cases related to the dialect post-parsing upgrade // mechanism. +// COM: those tests parse bytecode that was generated before test dialect +// adopted `usePropertiesFromAttributes`. //===--------------------------------------------------------------------===// // Test generic @@ -10,7 +12,7 @@ // COM: version: 2.0 // COM: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> () // COM: } -// RUN: mlir-opt %S/versioned-op-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// RUN: mlir-opt %S/versioned-op-with-prop-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 // CHECK1: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> () //===--------------------------------------------------------------------===// @@ -22,8 +24,8 @@ // COM: version: 1.12 // COM: "test.versionedA"() <{dimensions = 123 : i64}> : () -> () // COM: } -// RUN: mlir-opt %S/versioned-op-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK2 -// CHECK2: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> () +// RUN: mlir-opt %S/versioned-op-with-prop-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK3 +// CHECK3: "test.versionedA"() <{dims = 123 : i64, modifier = false}> : () -> () 
//===--------------------------------------------------------------------===// // Test forbidden downgrade diff --git a/mlir/test/Bytecode/versioning/versioned_op_with_native_properties.mlir b/mlir/test/Bytecode/versioning/versioned_op_with_native_properties.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Bytecode/versioning/versioned_op_with_native_properties.mlir @@ -0,0 +1,28 @@ +// This file contains test cases related to the dialect post-parsing upgrade +// mechanism. +// COM: these tests parse bytecode for `test.with_versioned_properties`, whose +// native properties encoding depends on the test dialect version. + +//===--------------------------------------------------------------------===// +// Test generic +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 2.0 +// COM: test.with_versioned_properties 1 | 2 +// COM: } +// RUN: mlir-opt %S/versioned-op-with-native-prop-2.0.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK1 +// CHECK1: test.with_versioned_properties 1 | 2 + +//===--------------------------------------------------------------------===// +// Test upgrade +//===--------------------------------------------------------------------===// + +// COM: bytecode contains +// COM: module { +// COM: version: 1.12 +// COM: test.with_versioned_properties 1 +// COM: } +// RUN: mlir-opt %S/versioned-op-with-native-prop-1.12.mlirbc 2>&1 | FileCheck %s --check-prefix=CHECK3 +// CHECK3: test.with_versioned_properties 1 | 0 diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -14,9 +14,9 @@ #ifndef MLIR_TESTDIALECT_H #define MLIR_TESTDIALECT_H -#include "TestTypes.h" #include "TestAttributes.h" #include "TestInterfaces.h" +#include "TestTypes.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/Traits.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -84,6 +84,17 @@ return content ==
rhs.content; } }; +struct VersionedProperties { + // For the sake of testing, assume that test dialect versions prior to 2.0 + // (e.g. 1.12) encoded this property as a single int value. In the current + // version 2.0, the property has two values. We also assume that the class is + // upgrade-able if value2 = 0. + int value1; + int value2; + bool operator==(const VersionedProperties &rhs) const { + return value1 == rhs.value1 && value2 == rhs.value2; + } +}; } // namespace test #define GET_OP_CLASSES diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -114,6 +114,16 @@ const PropertiesWithCustomPrint &prop); static ParseResult customParseProperties(OpAsmParser &parser, PropertiesWithCustomPrint &prop); +static LogicalResult setPropertiesFromAttribute(VersionedProperties &prop, + Attribute attr, + InFlightDiagnostic *diagnostic); +static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx, + const VersionedProperties &prop); +static llvm::hash_code computeHash(const VersionedProperties &prop); +static void customPrintProperties(OpAsmPrinter &p, + const VersionedProperties &prop); +static ParseResult customParseProperties(OpAsmParser &parser, + VersionedProperties &prop); void test::registerTestDialect(DialectRegistry &registry) { registry.insert(); @@ -211,14 +221,11 @@ << "current test dialect version is 2.0, can't parse version: " << version.major << "." << version.minor; } - // Prior version 2.0, the old op supported only a single attribute called - // "dimensions". We can perform the upgrade. topLevelOp->walk([](TestVersionedOpA op) { - if (auto dims = op->getAttr("dimensions")) { - op->removeAttr("dimensions"); - op->setAttr("dims", dims); - } - op->setAttr("modifier", BoolAttr::get(op->getContext(), false)); + // Prior to version 2.0, `readProperties` did not process the modifier + // attribute.
Handle that according to the version here. + auto &prop = op.getProperties(); + prop.modifier = BoolAttr::get(op->getContext(), false); }); return success(); } @@ -1932,6 +1939,54 @@ prop.label = std::make_shared(std::move(label)); return success(); } +static LogicalResult +setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, + InFlightDiagnostic *diagnostic) { + DictionaryAttr dict = dyn_cast(attr); + if (!dict) { + if (diagnostic) + *diagnostic << "expected DictionaryAttr to set VersionedProperties"; + return failure(); + } + auto value1Attr = dict.getAs("value1"); + if (!value1Attr) { + if (diagnostic) + *diagnostic << "expected IntegerAttr for key `value1`"; + return failure(); + } + auto value2Attr = dict.getAs("value2"); + if (!value2Attr) { + if (diagnostic) + *diagnostic << "expected IntegerAttr for key `value2`"; + return failure(); + } + + prop.value1 = value1Attr.getValue().getSExtValue(); + prop.value2 = value2Attr.getValue().getSExtValue(); + return success(); +} +static DictionaryAttr +getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) { + SmallVector attrs; + Builder b{ctx}; + attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1))); + attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2))); + return b.getDictionaryAttr(attrs); +} +static llvm::hash_code computeHash(const VersionedProperties &prop) { + return llvm::hash_combine(prop.value1, prop.value2); +} +static void customPrintProperties(OpAsmPrinter &p, + const VersionedProperties &prop) { + p << prop.value1 << " | " << prop.value2; +} +static ParseResult customParseProperties(OpAsmParser &parser, + VersionedProperties &prop) { + if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() || + parser.parseInteger(prop.value2)) + return failure(); + return success(); +} static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { return parser.parseLSquare() || parser.parseInteger(value[0]) 
|| @@ -1945,6 +2000,69 @@ printer << '[' << value << ']'; } +LogicalResult +TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader, + ::mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); + if (::mlir::failed(reader.readAttribute(prop.dims))) + return ::mlir::failure(); + + // Check if we have a version. If not, assume we are parsing the current + // version. + auto maybeVersion = reader.getDialectVersion("test"); + if (succeeded(maybeVersion)) { + // If version is less than 2.0, there is no additional attribute to parse. + // We can materialize missing properties post parsing before verification. + const auto *version = + reinterpret_cast(*maybeVersion); + if ((version->major < 2)) { + return success(); + } + } + + if (::mlir::failed(reader.readAttribute(prop.modifier))) + return ::mlir::failure(); + return ::mlir::success(); +} + +void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); + writer.writeAttribute(prop.dims); + writer.writeAttribute(prop.modifier); +} + +::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( + ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { + uint64_t value1, value2 = 0; + if (failed(reader.readVarInt(value1))) + return failure(); + + // Check if we have a version. If not, assume we are parsing the current + // version. + auto maybeVersion = reader.getDialectVersion("test"); + bool needToParseAnotherInt = true; + if (succeeded(maybeVersion)) { + // If version is less than 2.0, the encoding has no second integer: + // value2 keeps its default of 0.
+ const auto *version = + reinterpret_cast(*maybeVersion); + if ((version->major < 2)) + needToParseAnotherInt = false; + } + if (needToParseAnotherInt && failed(reader.readVarInt(value2))) + return failure(); + + prop.value1 = value1; + prop.value2 = value2; + return success(); +} +void TestOpWithVersionedProperties::writeToMlirBytecode( + ::mlir::DialectBytecodeWriter &writer, + const test::VersionedProperties &prop) { + writer.writeVarInt(prop.value1); + writer.writeVarInt(prop.value2); +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestTypeInterfaces.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3295,6 +3295,10 @@ AnyI64Attr:$dims, BoolAttr:$modifier ); + + // Since we use properties to store attributes, we need a custom encoding + // reader/writer to handle versioning. + let useCustomPropertiesEncoding = 1; } def TestVersionedOpB : TEST_Op<"versionedB"> { @@ -3414,6 +3418,51 @@ }]; } +def VersionedProperties : Property<"VersionedProperties"> { + let convertToAttribute = [{ + getPropertiesAsAttribute($_ctxt, $_storage) + }]; + let convertFromAttribute = [{ + return setPropertiesFromAttribute($_storage, $_attr, $_diag); + }]; + let hashProperty = [{ + computeHash($_storage); + }]; +} + +def TestOpWithVersionedProperties : TEST_Op<"with_versioned_properties"> { + let assemblyFormat = "prop-dict attr-dict"; + let arguments = (ins + VersionedProperties:$prop + ); + let extraClassDeclaration = [{ + void printProperties(::mlir::MLIRContext *ctx, ::mlir::OpAsmPrinter &p, + const Properties &prop); + static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser, + ::mlir::OperationState &result); + static ::mlir::LogicalResult readFromMlirBytecode( + ::mlir::DialectBytecodeReader &, + test::VersionedProperties &prop); + static void writeToMlirBytecode( + 
::mlir::DialectBytecodeWriter &, + const test::VersionedProperties &prop); + }]; + let extraClassDefinition = [{ + void TestOpWithVersionedProperties::printProperties(::mlir::MLIRContext *ctx, + ::mlir::OpAsmPrinter &p, const Properties &prop) { + customPrintProperties(p, prop.prop); + } + ::mlir::ParseResult TestOpWithVersionedProperties::parseProperties( + ::mlir::OpAsmParser &parser, + ::mlir::OperationState &result) { + Properties &prop = result.getOrAddProperties(); + if (customParseProperties(parser, prop.prop)) + return failure(); + return success(); + } + }]; +} + #endif // TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -420,6 +420,9 @@ namespace { // Helper class to emit a record into the given output stream. class OpEmitter { + using ConstArgument = + llvm::PointerUnion; + public: static void emitDecl(const Operator &op, raw_ostream &os, @@ -447,6 +450,9 @@ // Generates code to manage the properties, if any! void genPropertiesSupport(); + // Generates code to manage the encoding of properties to bytecode. + void genPropertiesSupportForBytecode(ArrayRef attrOrProperties); + // Generates getters for the attributes. 
void genAttrGetters(); @@ -1069,8 +1075,6 @@ void OpEmitter::genPropertiesSupport() { if (!emitHelper.hasProperties()) return; - using ConstArgument = - llvm::PointerUnion; SmallVector attrOrProperties; for (const std::pair &it : @@ -1131,21 +1135,6 @@ "getDiag")) ->body(); - auto &readPropertiesMethod = - opClass - .addStaticMethod( - "::mlir::LogicalResult", "readProperties", - MethodParameter("::mlir::DialectBytecodeReader &", "reader"), - MethodParameter("::mlir::OperationState &", "state")) - ->body(); - - auto &writePropertiesMethod = - opClass - .addMethod( - "void", "writeProperties", - MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) - ->body(); - opClass.declare("Properties", "FoldAdaptor::Properties"); // Convert the property to the attribute form. @@ -1189,7 +1178,8 @@ .addSubst("_diag", propertyDiag)), name); } else { - const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); + const auto *namedAttr = + llvm::dyn_cast_if_present(attrOrProp); StringRef name = namedAttr->attrName; setPropMethod << formatv(R"decl( {{ @@ -1242,7 +1232,8 @@ .addSubst("_storage", propertyStorage))); continue; } - const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); + const auto *namedAttr = + llvm::dyn_cast_if_present(attrOrProp); StringRef name = namedAttr->attrName; getPropMethod << formatv(R"decl( {{ @@ -1325,7 +1316,8 @@ // syntax. This method verifies the constraint on the properties attributes // before they are set, since dyn_cast<> will silently omit failures. for (const auto &attrOrProp : attrOrProperties) { - const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); + const auto *namedAttr = + llvm::dyn_cast_if_present(attrOrProp); if (!namedAttr || !namedAttr->constraint) continue; Attribute attr = *namedAttr->constraint; @@ -1349,6 +1341,38 @@ } verifyInherentAttrsMethod << " return ::mlir::success();"; + // Generate methods to interact with bytecode. 
+ genPropertiesSupportForBytecode(attrOrProperties); +} + +void OpEmitter::genPropertiesSupportForBytecode( + ArrayRef attrOrProperties) { + if (op.useCustomPropertiesEncoding()) { + opClass.declareStaticMethod( + "::mlir::LogicalResult", "readProperties", + MethodParameter("::mlir::DialectBytecodeReader &", "reader"), + MethodParameter("::mlir::OperationState &", "state")); + opClass.declareMethod( + "void", "writeProperties", + MethodParameter("::mlir::DialectBytecodeWriter &", "writer")); + return; + } + + auto &readPropertiesMethod = + opClass + .addStaticMethod( + "::mlir::LogicalResult", "readProperties", + MethodParameter("::mlir::DialectBytecodeReader &", "reader"), + MethodParameter("::mlir::OperationState &", "state")) + ->body(); + + auto &writePropertiesMethod = + opClass + .addMethod( + "void", "writeProperties", + MethodParameter("::mlir::DialectBytecodeWriter &", "writer")) + ->body(); + // Populate bytecode serialization logic. readPropertiesMethod << " auto &prop = state.getOrAddProperties(); (void)prop;"; @@ -2576,7 +2600,8 @@ // Calculate the start index from which we can attach default values in the // builder declaration. 
for (int i = op.getNumArgs() - 1; i >= 0; --i) { - auto *namedAttr = llvm::dyn_cast_if_present(op.getArg(i)); + auto *namedAttr = + llvm::dyn_cast_if_present(op.getArg(i)); if (!namedAttr || !namedAttr->attr.hasDefaultValue()) break; @@ -2606,7 +2631,8 @@ for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { Argument arg = op.getArg(i); - if (const auto *operand = llvm::dyn_cast_if_present(arg)) { + if (const auto *operand = + llvm::dyn_cast_if_present(arg)) { StringRef type; if (operand->isVariadicOfVariadic()) type = "::llvm::ArrayRef<::mlir::ValueRange>"; @@ -3583,7 +3609,8 @@ .addSubst("_storage", propertyStorage))); continue; } - const auto *namedAttr = llvm::dyn_cast_if_present(attrOrProp); + const auto *namedAttr = + llvm::dyn_cast_if_present(attrOrProp); const Attribute *attr = nullptr; if (namedAttr->constraint) attr = &*namedAttr->constraint;