diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -39,6 +39,9 @@ /// Emit an error to the reader. virtual InFlightDiagnostic emitError(const Twine &msg = {}) = 0; + /// Return the bytecode version being read. + virtual uint64_t getBytecodeVersion() const = 0; + /// Read out a list of elements, invoking the provided callback for each /// element. The callback function may be in any of the following forms: /// * LogicalResult(T &) @@ -68,6 +71,34 @@ return success(); } + /// Read exactly N elements into `result`, invoking the callback for each; + /// fails before any callback if the encoded size is not N. Callback forms: + /// * LogicalResult(T &) + /// * FailureOr() + template + LogicalResult readList(T (&result)[N], CallbackFn &&callback) { + uint64_t size; + if (failed(readVarInt(size))) + return failure(); + if (size != N) + return emitError("Expected list of ") << N << " elements but has " << size; + + for (T &elt : result) { + // Check if the callback uses FailureOr, or populates the result by + // reference. + if constexpr (llvm::function_traits>::num_args) { + if (failed(callback(elt))) + return failure(); + } else { + FailureOr parsedElt = callback(); + if (failed(parsedElt)) + return failure(); + elt = std::move(*parsedElt); + } + } + return success(); + } + //===--------------------------------------------------------------------===// // IR //===--------------------------------------------------------------------===// @@ -196,6 +227,15 @@ callback(element); } + /// Write out an array of elements, invoking the provided callback for each + /// element. + template + void writeList(T (&range)[N], CallbackFn &&callback) { + writeVarInt(N); + for (auto &element : range) + callback(element); + } + /// Write a reference to the given attribute.
virtual void writeAttribute(Attribute attr) = 0; virtual void writeOptionalAttribute(Attribute attr) = 0; diff --git a/mlir/include/mlir/Bytecode/Encoding.h b/mlir/include/mlir/Bytecode/Encoding.h --- a/mlir/include/mlir/Bytecode/Encoding.h +++ b/mlir/include/mlir/Bytecode/Encoding.h @@ -45,8 +45,12 @@ /// with the discardable attributes. kNativePropertiesEncoding = 5, + /// ODS emits operand/result segment_size as native properties instead of + /// an attribute. + kNativePropertiesODSSegmentSize = 6, + /// The current bytecode version. - kVersion = 5, + kVersion = 6, /// An arbitrary value used to fill alignment padding. kAlignmentByte = 0xCB, diff --git a/mlir/include/mlir/IR/ODSSupport.h b/mlir/include/mlir/IR/ODSSupport.h --- a/mlir/include/mlir/IR/ODSSupport.h +++ b/mlir/include/mlir/IR/ODSSupport.h @@ -37,9 +37,16 @@ LogicalResult convertFromAttribute(MutableArrayRef storage, Attribute attr, InFlightDiagnostic *diag); +/// Convert a DenseI32ArrayAttr to the provided storage. It is expected that the +/// storage has the same size as the array. An error is returned if the +/// attribute isn't a DenseI32ArrayAttr or it does not have the same size. If +/// the optional diagnostic is provided an error message is also emitted. +LogicalResult convertFromAttribute(MutableArrayRef storage, + Attribute attr, InFlightDiagnostic *diag); + /// Convert the provided ArrayRef to a DenseI64ArrayAttr attribute. 
Attribute convertToAttribute(MLIRContext *ctx, ArrayRef storage); } // namespace mlir #endif // MLIR_IR_ODSSUPPORT_H \ No newline at end of file diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1241,6 +1241,7 @@ let interfaceType = "::llvm::ArrayRef<" # storageTypeParam # ">"; let convertFromStorage = "$_storage"; let assignToStorage = "::llvm::copy($_value, $_storage)"; + let hashProperty = "llvm::hash_combine_range(std::begin($_storage), std::end($_storage));"; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -20,6 +20,7 @@ #define MLIR_IR_OPDEFINITION_H #include "mlir/IR/Dialect.h" +#include "mlir/IR/ODSSupport.h" #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -552,7 +552,8 @@ StringRef name) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - return ConcreteOp::getInherentAttr(concreteOp.getProperties(), name); + return ConcreteOp::getInherentAttr(concreteOp.getContext(), + concreteOp.getProperties(), name); } // If the op does not have support for properties, we dispatch back to the // dictionnary of discardable attributes for now. 
@@ -572,7 +573,8 @@ void populateInherentAttrs(Operation *op, NamedAttrList &attrs) final { if constexpr (hasProperties) { auto concreteOp = cast(op); - ConcreteOp::populateInherentAttrs(concreteOp.getProperties(), attrs); + ConcreteOp::populateInherentAttrs(concreteOp.getContext(), + concreteOp.getProperties(), attrs); } } LogicalResult diff --git a/mlir/include/mlir/TableGen/Property.h b/mlir/include/mlir/TableGen/Property.h --- a/mlir/include/mlir/TableGen/Property.h +++ b/mlir/include/mlir/TableGen/Property.h @@ -35,51 +35,76 @@ public: explicit Property(const llvm::Record *record); explicit Property(const llvm::DefInit *init); + Property(StringRef storageType, StringRef interfaceType, + StringRef convertFromStorageCall, StringRef assignToStorageCall, + StringRef convertToAttributeCall, StringRef convertFromAttributeCall, + StringRef readFromMlirBytecodeCall, + StringRef writeToMlirBytecodeCall, StringRef hashPropertyCall, + StringRef defaultValue); // Returns the storage type. - StringRef getStorageType() const; + StringRef getStorageType() const { return storageType; } // Returns the interface type for this property. - StringRef getInterfaceType() const; + StringRef getInterfaceType() const { return interfaceType; } // Returns the template getter method call which reads this property's // storage and returns the value as of the desired return type. - StringRef getConvertFromStorageCall() const; + StringRef getConvertFromStorageCall() const { return convertFromStorageCall; } // Returns the template setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getAssignToStorageCall() const; + StringRef getAssignToStorageCall() const { return assignToStorageCall; } // Returns the conversion method call which reads this property's // in the storage type and builds an attribute. 
- StringRef getConvertToAttributeCall() const; + StringRef getConvertToAttributeCall() const { return convertToAttributeCall; } // Returns the setter method call which reads this property's // in the provided interface type and assign it to the storage. - StringRef getConvertFromAttributeCall() const; + StringRef getConvertFromAttributeCall() const { + return convertFromAttributeCall; + } // Returns the method call which reads this property from // bytecode and assign it to the storage. - StringRef getReadFromMlirBytecodeCall() const; + StringRef getReadFromMlirBytecodeCall() const { + return readFromMlirBytecodeCall; + } // Returns the method call which write this property's // to the the bytecode. - StringRef getWriteToMlirBytecodeCall() const; + StringRef getWriteToMlirBytecodeCall() const { + return writeToMlirBytecodeCall; + } // Returns the code to compute the hash for this property. - StringRef getHashPropertyCall() const; + StringRef getHashPropertyCall() const { return hashPropertyCall; } // Returns whether this Property has a default value. - bool hasDefaultValue() const; + bool hasDefaultValue() const { return !defaultValue.empty(); } + // Returns the default value for this Property. - StringRef getDefaultValue() const; + StringRef getDefaultValue() const { return defaultValue; } // Returns the TableGen definition this Property was constructed from. - const llvm::Record &getDef() const; + const llvm::Record &getDef() const { return *def; } private: // The TableGen definition of this constraint. const llvm::Record *def; + + // Elements describing a Property, in general fetched from the record. 
+ StringRef storageType; + StringRef interfaceType; + StringRef convertFromStorageCall; + StringRef assignToStorageCall; + StringRef convertToAttributeCall; + StringRef convertFromAttributeCall; + StringRef readFromMlirBytecodeCall; + StringRef writeToMlirBytecodeCall; + StringRef hashPropertyCall; + StringRef defaultValue; }; // A struct wrapping an op property and its name together diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -30,6 +30,7 @@ #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/SourceMgr.h" #include +#include #include #include #include @@ -796,9 +797,10 @@ public: AttrTypeReader(StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, Location fileLoc) + ResourceSectionReader &resourceReader, Location fileLoc, + uint64_t &bytecodeVersion) : stringReader(stringReader), resourceReader(resourceReader), - fileLoc(fileLoc) {} + fileLoc(fileLoc), bytecodeVersion(bytecodeVersion) {} /// Initialize the attribute and type information within the reader. LogicalResult initialize(MutableArrayRef dialects, @@ -883,23 +885,30 @@ /// A location used for error emission. Location fileLoc; + + /// Current bytecode version being used. 
+ uint64_t &bytecodeVersion; }; class DialectReader : public DialectBytecodeReader { public: DialectReader(AttrTypeReader &attrTypeReader, StringSectionReader &stringReader, - ResourceSectionReader &resourceReader, EncodingReader &reader) + ResourceSectionReader &resourceReader, EncodingReader &reader, + uint64_t &bytecodeVersion) : attrTypeReader(attrTypeReader), stringReader(stringReader), - resourceReader(resourceReader), reader(reader) {} + resourceReader(resourceReader), reader(reader), + bytecodeVersion(bytecodeVersion) {} InFlightDiagnostic emitError(const Twine &msg) override { return reader.emitError(msg); } + uint64_t getBytecodeVersion() const override { return bytecodeVersion; } + DialectReader withEncodingReader(EncodingReader &encReader) { return DialectReader(attrTypeReader, stringReader, resourceReader, - encReader); + encReader, bytecodeVersion); } Location getLoc() const { return reader.getLoc(); } @@ -1003,6 +1012,7 @@ StringSectionReader &stringReader; ResourceSectionReader &resourceReader; EncodingReader &reader; + uint64_t &bytecodeVersion; }; /// Wraps the properties section and handles reading properties out of it. @@ -1207,7 +1217,8 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, StringRef entryType) { - DialectReader dialectReader(*this, stringReader, resourceReader, reader); + DialectReader dialectReader(*this, stringReader, resourceReader, reader, + bytecodeVersion); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); // Ensure that the dialect implements the bytecode interface. 
@@ -1252,7 +1263,7 @@ llvm::MemoryBufferRef buffer, const std::shared_ptr &bufferOwnerRef) : config(config), fileLoc(fileLoc), lazyLoading(lazyLoading), - attrTypeReader(stringReader, resourceReader, fileLoc), + attrTypeReader(stringReader, resourceReader, fileLoc, version), // Use the builtin unrealized conversion cast operation to represent // forward references to values that aren't yet defined. forwardRefOpState(UnknownLoc::get(config.getContext()), @@ -1782,7 +1793,7 @@ if (!opName->opName) { // Load the dialect and its version. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); if (failed(opName->dialect->load(dialectReader, getContext()))) return failure(); // If the opName is empty, this is because we use to accept names such as @@ -1825,7 +1836,7 @@ // Initialize the resource reader with the resource sections. DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); return resourceReader.initialize(fileLoc, config, dialects, stringReader, *resourceData, *resourceOffsetData, dialectReader, bufferOwnerRef); @@ -2186,7 +2197,7 @@ // interface and control the serialization. 
if (wasRegistered) { DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, - reader); + reader, version); if (failed( propertiesReader.read(fileLoc, dialectReader, &*opName, opState))) return failure(); diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -33,13 +33,16 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, int64_t storage) { return IntegerAttr::get(IntegerType::get(ctx, 64), storage); } -LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, - ::mlir::Attribute attr, - ::mlir::InFlightDiagnostic *diag) { - auto valueAttr = dyn_cast(attr); + +template +LogicalResult convertDenseArrayFromAttr(MutableArrayRef storage, + ::mlir::Attribute attr, + ::mlir::InFlightDiagnostic *diag, + StringRef denseArrayTyStr) { + auto valueAttr = dyn_cast(attr); if (!valueAttr) { if (diag) - *diag << "expected DenseI64ArrayAttr for key `value`"; + *diag << "expected " << denseArrayTyStr << " for key `value`"; return failure(); } if (valueAttr.size() != static_cast(storage.size())) { @@ -51,6 +54,19 @@ llvm::copy(valueAttr.asArrayRef(), storage.begin()); return success(); } +LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, + ::mlir::Attribute attr, + ::mlir::InFlightDiagnostic *diag) { + return convertDenseArrayFromAttr(storage, attr, diag, + "DenseI64ArrayAttr"); +} +LogicalResult mlir::convertFromAttribute(MutableArrayRef storage, + Attribute attr, + InFlightDiagnostic *diag) { + return convertDenseArrayFromAttr(storage, attr, diag, + "DenseI32ArrayAttr"); +} + Attribute mlir::convertToAttribute(MLIRContext *ctx, ArrayRef storage) { return DenseI64ArrayAttr::get(ctx, storage); diff --git a/mlir/lib/TableGen/Property.cpp b/mlir/lib/TableGen/Property.cpp --- a/mlir/lib/TableGen/Property.cpp +++ b/mlir/lib/TableGen/Property.cpp @@ -32,65 +32,40 @@ return {}; } -Property::Property(const Record *record) : def(record) { - 
assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) && +Property::Property(const Record *record) + : Property(getValueAsString(record->getValueInit("storageType")), + getValueAsString(record->getValueInit("interfaceType")), + getValueAsString(record->getValueInit("convertFromStorage")), + getValueAsString(record->getValueInit("assignToStorage")), + getValueAsString(record->getValueInit("convertToAttribute")), + getValueAsString(record->getValueInit("convertFromAttribute")), + getValueAsString(record->getValueInit("readFromMlirBytecode")), + getValueAsString(record->getValueInit("writeToMlirBytecode")), + getValueAsString(record->getValueInit("hashProperty")), + getValueAsString(record->getValueInit("defaultValue"))) { + def = record; + assert((record->isSubClassOf("Property") || record->isSubClassOf("Attr")) && "must be subclass of TableGen 'Property' class"); } Property::Property(const DefInit *init) : Property(init->getDef()) {} -StringRef Property::getStorageType() const { - const auto *init = def->getValueInit("storageType"); - auto type = getValueAsString(init); - if (type.empty()) - return "Property"; - return type; +Property::Property(StringRef storageType, StringRef interfaceType, + StringRef convertFromStorageCall, + StringRef assignToStorageCall, + StringRef convertToAttributeCall, + StringRef convertFromAttributeCall, + StringRef readFromMlirBytecodeCall, + StringRef writeToMlirBytecodeCall, + StringRef hashPropertyCall, StringRef defaultValue) + : def(nullptr), storageType(storageType), interfaceType(interfaceType), + convertFromStorageCall(convertFromStorageCall), + assignToStorageCall(assignToStorageCall), + convertToAttributeCall(convertToAttributeCall), + convertFromAttributeCall(convertFromAttributeCall), + readFromMlirBytecodeCall(readFromMlirBytecodeCall), + writeToMlirBytecodeCall(writeToMlirBytecodeCall), + hashPropertyCall(hashPropertyCall), defaultValue(defaultValue) { + if (this->storageType.empty()) // The parameter shadows the member. + this->storageType = "Property"; } - -StringRef Property::getInterfaceType()
const { - const auto *init = def->getValueInit("interfaceType"); - return getValueAsString(init); -} - -StringRef Property::getConvertFromStorageCall() const { - const auto *init = def->getValueInit("convertFromStorage"); - return getValueAsString(init); -} - -StringRef Property::getAssignToStorageCall() const { - const auto *init = def->getValueInit("assignToStorage"); - return getValueAsString(init); -} - -StringRef Property::getConvertToAttributeCall() const { - const auto *init = def->getValueInit("convertToAttribute"); - return getValueAsString(init); -} - -StringRef Property::getConvertFromAttributeCall() const { - const auto *init = def->getValueInit("convertFromAttribute"); - return getValueAsString(init); -} - -StringRef Property::getReadFromMlirBytecodeCall() const { - const auto *init = def->getValueInit("readFromMlirBytecode"); - return getValueAsString(init); -} - -StringRef Property::getWriteToMlirBytecodeCall() const { - const auto *init = def->getValueInit("writeToMlirBytecode"); - return getValueAsString(init); -} - -StringRef Property::getHashPropertyCall() const { - return getValueAsString(def->getValueInit("hashProperty")); -} - -bool Property::hasDefaultValue() const { return !getDefaultValue().empty(); } - -StringRef Property::getDefaultValue() const { - const auto *init = def->getValueInit("defaultValue"); - return getValueAsString(init); -} - -const llvm::Record &Property::getDef() const { return *def; } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -383,21 +383,21 @@ // ----- func.func @failedMissingOperandSizeAttr(%arg: i32) { - // expected-error @+1 {{requires dense i32 array attribute 'operand_segment_sizes'}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> () } // ----- func.func 
@failedOperandSizeAttrWrongType(%arg: i32) { - // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> () } // ----- func.func @failedOperandSizeAttrWrongElementType(%arg: i32) { - // expected-error @+1 {{attribute 'operand_segment_sizes' failed to satisfy constraint: i32 dense array attribute}} + // expected-error @+1 {{op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -418,7 +418,7 @@ // ----- func.func @failedOperandSizeAttrWrongCount(%arg: i32) { - // expected-error @+1 {{'operand_segment_sizes' attribute for specifying operand segments must have 4 elements}} + // expected-error @+1 {{test.attr_sized_operands' op operand count (4) does not match with the total size (0) specified in attribute 'operand_segment_sizes}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = array} : (i32, i32, i32, i32) -> () } @@ -433,14 +433,14 @@ // ----- func.func @failedMissingResultSizeAttr() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32) } // ----- func.func @failedResultSizeAttrWrongType() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() 
{result_segment_sizes = 10} : () -> (i32, i32, i32, i32) } @@ -448,7 +448,7 @@ // ----- func.func @failedResultSizeAttrWrongElementType() { - // expected-error @+1 {{requires dense i32 array attribute 'result_segment_sizes'}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } @@ -469,7 +469,7 @@ // ----- func.func @failedResultSizeAttrWrongCount() { - // expected-error @+1 {{'result_segment_sizes' attribute for specifying result segments must have 4 elements, but got 3}} + // expected-error @+1 {{ op result count (4) does not match with the total size (0) specified in attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = array} : () -> (i32, i32, i32, i32) } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -839,8 +839,7 @@ Variadic:$a, Variadic:$b, I32:$c, - Variadic:$d, - DenseI32ArrayAttr:$operand_segment_sizes + Variadic:$d ); } diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -115,6 +115,10 @@ assert({0} && "missing segment size attribute for op"); auto sizeAttr = ::llvm::cast<::mlir::DenseI32ArrayAttr>({0}); )"; +static const char *const adapterSegmentSizeAttrInitCodeProperties = R"( + ::llvm::ArrayRef sizeAttr = {0}; +)"; + /// The code snippet to initialize the sizes for the value range calculation. /// /// {0}: The code to get the attribute. @@ -150,6 +154,37 @@ std::next({0}, valueRange.first + valueRange.second)}; )"; +/// Read operand/result segment_size from bytecode. 
+static const char *const readBytecodeSegmentSize = R"( +if ($_reader.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) { + DenseI32ArrayAttr attr; + if (::mlir::failed($_reader.readAttribute(attr))) return failure(); + if (attr.size() != static_cast(sizeof($_storage) / sizeof(int32_t))) { + $_reader.emitError("Size mismatch for operand/result_segment_size"); + return failure(); + } + llvm::copy(ArrayRef(attr), $_storage); +} else { + if (::mlir::failed($_reader.readList($_storage, + [&]() -> FailureOr { + uint64_t elt; + if(failed($_reader.readVarInt(elt))) return failure(); + return elt; + }))) + return ::mlir::failure(); +} +)"; + +/// Write operand/result segment sizes (attribute before v6, varints after). +static const char *const writeBytecodeSegmentSize = R"( +if ($_writer.getBytecodeVersion() < /*kNativePropertiesODSSegmentSize=*/6) { + $_writer.writeAttribute(DenseI32ArrayAttr::get(getContext(), $_storage)); +} else { + $_writer.writeList($_storage, + [&](int32_t elt) { $_writer.writeVarInt(elt); }); +} +)"; + /// A header for indicating code sections. /// /// {0}: Some text, or a class name. @@ -343,6 +378,9 @@ return true; if (!op.getDialect().usePropertiesForAttributes()) return false; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || + op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) + return true; return llvm::any_of(getAttrMetadata(), [](const std::pair &it) { return !it.second.constraint || @@ -350,6 +388,14 @@ }); } + std::optional &getOperandSegmentsSize() { + return operandSegmentsSize; + } + + std::optional &getResultSegmentsSize() { + return resultSegmentsSize; + } + private: // Compute the attribute metadata. void computeAttrMetadata(); @@ -361,6 +407,13 @@ // The attribute metadata, mapped by name.
llvm::MapVector attrMetadata; + + // Property + std::optional operandSegmentsSize; + std::string operandSegmentsSizeStorage; + std::optional resultSegmentsSize; + std::string resultSegmentsSizeStorage; + // The number of required attributes. unsigned numRequired; }; @@ -377,18 +430,50 @@ attrMetadata.insert( {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); } + + auto makeProperty = [&](StringRef storageType) { + return Property( + /*storageType=*/storageType, + /*interfaceType=*/"::llvm::ArrayRef", + /*convertFromStorageCall=*/"$_storage", + /*assignToStorageCall=*/"::llvm::copy($_value, $_storage)", + /*convertToAttributeCall=*/ + "DenseI32ArrayAttr::get($_ctxt, $_storage)", + /*convertFromAttributeCall=*/ + "return convertFromAttribute($_storage, $_attr, $_diag);", + /*readFromMlirBytecodeCall=*/readBytecodeSegmentSize, + /*writeToMlirBytecodeCall=*/writeBytecodeSegmentSize, + /*hashPropertyCall=*/ + "llvm::hash_combine_range(std::begin($_storage), " + "std::end($_storage));", + /*StringRef defaultValue=*/""); + }; // Include key attributes from several traits as implicitly registered. 
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - attrMetadata.insert( - {operandSegmentAttrName, - AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, - /*attr=*/std::nullopt}}); + if (op.getDialect().usePropertiesForAttributes()) { + operandSegmentsSizeStorage = + llvm::formatv("int32_t[{0}]", op.getNumOperands()); + operandSegmentsSize = {operandSegmentAttrName, + makeProperty(operandSegmentsSizeStorage)}; + } else { + attrMetadata.insert( + {operandSegmentAttrName, AttributeMetadata{operandSegmentAttrName, + /*isRequired=*/true, + /*attr=*/std::nullopt}}); + } } if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - attrMetadata.insert( - {resultSegmentAttrName, - AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, - /*attr=*/std::nullopt}}); + if (op.getDialect().usePropertiesForAttributes()) { + resultSegmentsSizeStorage = + llvm::formatv("int32_t[{0}]", op.getNumResults()); + resultSegmentsSize = {resultSegmentAttrName, + makeProperty(resultSegmentsSizeStorage)}; + } else { + attrMetadata.insert( + {resultSegmentAttrName, + AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, + /*attr=*/std::nullopt}}); + } } // Store the metadata in sorted order. @@ -660,14 +745,17 @@ // Verify a few traits first so that we can use getODSOperands() and // getODSResults() in the rest of the verifier. 
auto &op = emitHelper.getOp(); - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, - op.getNumOperands(), "operand", - emitHelper.emitErrorPrefix()); - } - if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, - op.getNumResults(), "result", emitHelper.emitErrorPrefix()); + if (!op.getDialect().usePropertiesForAttributes()) { + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, + op.getNumOperands(), "operand", + emitHelper.emitErrorPrefix()); + } + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, + op.getNumResults(), "result", + emitHelper.emitErrorPrefix()); + } } } @@ -964,14 +1052,16 @@ void OpEmitter::genAttrNameGetters() { const llvm::MapVector &attributes = emitHelper.getAttrMetadata(); - + bool hasOperandSegmentsSize = + op.getDialect().usePropertiesForAttributes() && + op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); // Emit the getAttributeNames method. { auto *method = opClass.addStaticInlineMethod( "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); - if (attributes.empty()) { + if (!hasOperandSegmentsSize && attributes.empty()) { body << " return {};"; // Nothing else to do if there are no registered attributes. Exit early. 
return; @@ -981,6 +1071,11 @@ [&](StringRef attrName) { body << "::llvm::StringRef(\"" << attrName << "\")"; }); + if (hasOperandSegmentsSize) { + if (!attributes.empty()) + body << ", "; + body << "::llvm::StringRef(\"" << operandSegmentAttrName << "\")"; + } body << "};\n return ::llvm::ArrayRef(attrNames);"; } @@ -1033,6 +1128,26 @@ "name, " + Twine(index)); } } + if (hasOperandSegmentsSize) { + std::string name = op.getGetterName(operandSegmentAttrName); + std::string methodName = name + "AttrName"; + // Generate the non-static variant. + { + auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); + ERROR_IF_PRUNED(method, methodName, op); + method->body() + << " return (*this)->getName().getAttributeNames().back();"; + } + + // Generate the static variant. + { + auto *method = opClass.addStaticInlineMethod( + "::mlir::StringAttr", methodName, + MethodParameter("::mlir::OperationName", "name")); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << " return name.getAttributeNames().back();"; + } + } } // Emit the getter for an attribute with the return type specified. 
@@ -1080,6 +1195,10 @@ } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); + if (emitHelper.getOperandSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); + if (emitHelper.getResultSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); if (attrOrProperties.empty()) return; auto &setPropMethod = @@ -1104,6 +1223,7 @@ auto &getInherentAttrMethod = opClass .addStaticMethod("std::optional", "getInherentAttr", + MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("llvm::StringRef", "name")) ->body(); @@ -1117,6 +1237,7 @@ auto &populateInherentAttrsMethod = opClass .addStaticMethod("void", "populateInherentAttrs", + MethodParameter("::mlir::MLIRContext *", "ctx"), MethodParameter("const Properties &", "prop"), MethodParameter("::mlir::NamedAttrList &", "attrs")) ->body(); @@ -1318,6 +1439,34 @@ << formatv(populateInherentAttrsMethodFmt, name); continue; } + // The ODS segment size property is "special": we expose it as an attribute + // even though it is a native property. 
+ const auto *namedProperty = cast(attrOrProp); + StringRef name = namedProperty->name; + if (name != operandSegmentAttrName && name != resultSegmentAttrName) + continue; + auto &prop = namedProperty->prop; + FmtContext fctx; + fctx.addSubst("_ctxt", "ctx"); + fctx.addSubst("_storage", Twine("prop.") + name); + getInherentAttrMethod << formatv(" if (name == \"{0}\") return ", name); + getInherentAttrMethod << tgfmt(prop.getConvertToAttributeCall(), &fctx) + << ";\n"; + + setInherentAttrMethod << formatv(R"decl( + if (name == "{0}") {{ + auto arrAttr = dyn_cast_or_null(value); + if (!arrAttr) return; + if (arrAttr.size() != sizeof(prop.{0}) / sizeof(int32_t)) + return; + llvm::copy(arrAttr.asArrayRef(), prop.{0}); + return; + } +)decl", + name); + populateInherentAttrsMethod + << formatv(" attrs.append(\"{0}\", {1});\n", name, + tgfmt(prop.getConvertToAttributeCall(), &fctx)); } getInherentAttrMethod << " return std::nullopt;\n"; @@ -1815,8 +1964,14 @@ // array. std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + attrSizeInitCode = + formatv(adapterSegmentSizeAttrInitCodeProperties, + llvm::formatv("getProperties().{0}", operandSegmentAttrName)); + + else + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( @@ -1853,7 +2008,8 @@ if (emitHelper.hasProperties()) body << formatv( ", ::mlir::MutableOperandRange::OperandSegment({0}u, " - "{getOperandSegmentSizesAttrName(), getProperties().{1}})", + "{getOperandSegmentSizesAttrName(), " + "DenseI32ArrayAttr::get(getContext(), getProperties().{1})})", i, operandSegmentAttrName); else body << formatv( @@ -1910,8 +2066,14 @@ // Build the initializer string for the result segment size attribute. 
std::string attrSizeInitCode; if (attrSizedResults) { - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, - emitHelper.getAttr(resultSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + attrSizeInitCode = + formatv(adapterSegmentSizeAttrInitCodeProperties, + llvm::formatv("getProperties().{0}", resultSegmentAttrName)); + + else + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( @@ -2086,10 +2248,7 @@ // the length of the type ranges. if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { if (op.getDialect().usePropertiesForAttributes()) { - body << " (" << builderOpState - << ".getOrAddProperties()." << resultSegmentAttrName - << " = \n" - " odsBuilder.getDenseI32ArrayAttr({"; + body << " llvm::copy(ArrayRef({"; } else { std::string getterName = op.getGetterName(resultSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << getterName @@ -2112,7 +2271,13 @@ body << "static_cast(" << resultNames[i] << ".size())"; } }); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "}), " << builderOpState + << ".getOrAddProperties()." << resultSegmentAttrName + << ");\n"; + } else { + body << "}));\n"; + } } return; @@ -2706,17 +2871,7 @@ } // If the operation has the operand segment size attribute, add it here. - if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string sizes = op.getGetterName(operandSegmentAttrName); - if (op.getDialect().usePropertiesForAttributes()) { - body << " (" << builderOpState << ".getOrAddProperties()." 
- << operandSegmentAttrName << "= " - << "odsBuilder.getDenseI32ArrayAttr({"; - } else { - body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" - << builderOpState << ".name), " - << "odsBuilder.getDenseI32ArrayAttr({"; - } + auto emitSegment = [&]() { interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { const NamedTypeConstraint &operand = op.getOperand(i); if (!operand.isVariableLength()) { @@ -2737,7 +2892,21 @@ body << "static_cast(" << getArgumentName(op, i) << ".size())"; } }); - body << "}));\n"; + }; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + std::string sizes = op.getGetterName(operandSegmentAttrName); + if (op.getDialect().usePropertiesForAttributes()) { + body << " llvm::copy(ArrayRef({"; + emitSegment(); + body << "}), " << builderOpState << ".getOrAddProperties()." + << operandSegmentAttrName << ");\n"; + } else { + body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" + << builderOpState << ".name), " + << "odsBuilder.getDenseI32ArrayAttr({"; + emitSegment(); + body << "}));\n"; + } } // Push all attributes to the result. 
@@ -3541,6 +3710,10 @@ } for (const NamedProperty &prop : op.getProperties()) attrOrProperties.push_back(&prop); + if (emitHelper.getOperandSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); + if (emitHelper.getResultSegmentsSize()) + attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); @@ -3710,8 +3883,13 @@ std::string sizeAttrInit; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, - emitHelper.getAttr(operandSegmentAttrName)); + if (op.getDialect().usePropertiesForAttributes()) + sizeAttrInit = + formatv(adapterSegmentSizeAttrInitCodeProperties, + llvm::formatv("getProperties().{0}", operandSegmentAttrName)); + else + sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters(op, genericAdaptor, /*genericAdaptorBase=*/&genericAdaptorBase, diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1654,16 +1654,6 @@ MethodBody &body) { if (!allOperands) { if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - body << formatv(" " - "result.getOrAddProperties<{0}::Properties>().operand_" - "segment_sizes = " - "(parser.getBuilder().getDenseI32ArrayAttr({{", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"operand_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - } auto interleaveFn = [&](const NamedTypeConstraint &operand) { // If the operand is variadic emit the parsed size. 
if (operand.isVariableLength()) @@ -1671,8 +1661,19 @@ else body << "1"; }; - llvm::interleaveComma(op.getOperands(), body, interleaveFn); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "llvm::copy(ArrayRef({"; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << formatv("}), " + "result.getOrAddProperties<{0}::Properties>().operand_" + "segment_sizes);\n", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"operand_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << "}));\n"; + } } for (const NamedTypeConstraint &operand : op.getOperands()) { if (!operand.isVariadicOfVariadic()) @@ -1697,16 +1698,6 @@ if (!allResultTypes && op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { - if (op.getDialect().usePropertiesForAttributes()) { - body << formatv( - " " - "result.getOrAddProperties<{0}::Properties>().result_segment_sizes = " - "(parser.getBuilder().getDenseI32ArrayAttr({{", - op.getCppClassName()); - } else { - body << " result.addAttribute(\"result_segment_sizes\", " - << "parser.getBuilder().getDenseI32ArrayAttr({"; - } auto interleaveFn = [&](const NamedTypeConstraint &result) { // If the result is variadic emit the parsed size. 
if (result.isVariableLength()) @@ -1714,8 +1705,20 @@ else body << "1"; }; - llvm::interleaveComma(op.getResults(), body, interleaveFn); - body << "}));\n"; + if (op.getDialect().usePropertiesForAttributes()) { + body << "llvm::copy(ArrayRef({"; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << formatv( + "}), " + "result.getOrAddProperties<{0}::Properties>().result_segment_sizes" + ");\n", + op.getCppClassName()); + } else { + body << " result.addAttribute(\"result_segment_sizes\", " + << "parser.getBuilder().getDenseI32ArrayAttr({"; + llvm::interleaveComma(op.getResults(), body, interleaveFn); + body << "}));\n"; + } } } diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp --- a/mlir/unittests/IR/AdaptorTest.cpp +++ b/mlir/unittests/IR/AdaptorTest.cpp @@ -39,7 +39,7 @@ // value from the value 0. SmallVector> v = {0, 4}; OIListSimple::Properties prop; - prop.operand_segment_sizes = builder.getDenseI32ArrayAttr({1, 0, 1}); + llvm::copy(ArrayRef{1, 0, 1}, prop.operand_segment_sizes); OIListSimple::GenericAdaptor>> d(v, {}, prop, {}); EXPECT_EQ(d.getArg0(), 0); diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp --- a/mlir/unittests/IR/OpPropertiesTest.cpp +++ b/mlir/unittests/IR/OpPropertiesTest.cpp @@ -115,13 +115,15 @@ // This alias is the only definition needed for enabling "properties" for this // operation. using Properties = TestProperties; - static std::optional getInherentAttr(const Properties &prop, + static std::optional getInherentAttr(MLIRContext *context, + const Properties &prop, StringRef name) { return std::nullopt; } static void setInherentAttr(Properties &prop, StringRef name, mlir::Attribute value) {} - static void populateInherentAttrs(const Properties &prop, + static void populateInherentAttrs(MLIRContext *context, + const Properties &prop, NamedAttrList &attrs) {} static LogicalResult verifyInherentAttrs(OperationName opName, NamedAttrList &attrs,