diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -428,6 +428,23 @@
   return findAttrUnsorted(first, last, name);
 }
 
+/// Get an attribute from a sorted range of named attributes. Returns null if
+/// the attribute was not found.
+template <typename IteratorT, typename NameT>
+Attribute getAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) {
+  std::pair<IteratorT, bool> result = findAttrSorted(first, last, name);
+  return result.second ? result.first->getValue() : Attribute();
+}
+
+/// Get a named attribute from a sorted range of named attributes. Returns None
+/// if the attribute was not found.
+template <typename IteratorT, typename NameT>
+Optional<NamedAttribute>
+getNamedAttrFromSortedRange(IteratorT first, IteratorT last, NameT name) {
+  std::pair<IteratorT, bool> result = findAttrSorted(first, last, name);
+  return result.second ? *result.first : Optional<NamedAttribute>();
+}
+
 } // namespace impl
 
 //===----------------------------------------------------------------------===//
@@ -447,7 +464,7 @@
   NamedAttrList() : dictionarySorted({}, true) {}
   NamedAttrList(ArrayRef<NamedAttribute> attributes);
   NamedAttrList(DictionaryAttr attributes);
-  NamedAttrList(const_iterator in_start, const_iterator in_end);
+  NamedAttrList(const_iterator inStart, const_iterator inEnd);
 
   bool operator!=(const NamedAttrList &other) const {
     return !(*this == other);
@@ -478,15 +495,15 @@
             typename = std::enable_if_t<std::is_convertible<
                 typename std::iterator_traits<IteratorT>::iterator_category,
                 std::input_iterator_tag>::value>>
-  void append(IteratorT in_start, IteratorT in_end) {
+  void append(IteratorT inStart, IteratorT inEnd) {
     // TODO: expand to handle case where values appended are in order & after
     // end of current list.
     dictionarySorted.setPointerAndInt(nullptr, false);
-    attrs.append(in_start, in_end);
+    attrs.append(inStart, inEnd);
   }
 
   /// Replaces the attributes with new list of attributes.
-  void assign(const_iterator in_start, const_iterator in_end);
+  void assign(const_iterator inStart, const_iterator inEnd);
 
   /// Replaces the attributes with new list of attributes.
   void assign(ArrayRef<NamedAttribute> range) {
diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir
--- a/mlir/test/Dialect/LLVMIR/global.mlir
+++ b/mlir/test/Dialect/LLVMIR/global.mlir
@@ -81,7 +81,7 @@
 // -----
 
 // expected-error @+1 {{op requires attribute 'sym_name'}}
-"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
+"llvm.mlir.global"() ({}) {linkage = "private", type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
 
 // -----
 
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -230,7 +230,8 @@
   "test.int_attrs"() {
     any_i32_attr = 5.0 : f32,
     si32_attr = 7 : si32,
-    ui32_attr = 6 : ui32
+    ui32_attr = 6 : ui32,
+    index_attr = 1 : index
   } : () -> ()
   return
 }
diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td
--- a/mlir/test/mlir-tblgen/constraint-unique.td
+++ b/mlir/test/mlir-tblgen/constraint-unique.td
@@ -116,7 +116,7 @@
 
 /// Test that the uniqued constraints are being used.
 // CHECK-LABEL: OpA::verify
-// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: ::mlir::Attribute [[$B_ATTR:.*b]];
 // CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
 // CHECK-NEXT: return ::mlir::failure();
 // CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
@@ -137,7 +137,7 @@
 /// Test that the op with the same predicates but different with descriptions
 /// uses the different constraints.
 // CHECK-LABEL: OpC::verify
-// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName());
+// CHECK: ::mlir::Attribute [[$B_ATTR:.*b]];
 // CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b")))
 // CHECK-NEXT: return ::mlir::failure();
 // CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0);
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -65,29 +65,40 @@
 // ---
 
 // DEF:      ::mlir::LogicalResult AOpAdaptor::verify
-// DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr)
-// DEF-NEXT:   return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF:      ::mlir::Attribute tblgen_aAttr;
+// DEF-NEXT: while (true) {
+// DEF-NEXT:   if (namedAttrIt == namedAttrRange.end())
+// DEF-NEXT:     return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'");
+// DEF-NEXT:   if (namedAttrIt->getName() == AOp::aAttrAttrName(*odsOpName)) {
+// DEF-NEXT:     tblgen_aAttr = namedAttrIt->getValue();
+// DEF-NEXT:     break;
+// DEF:      ::mlir::Attribute tblgen_bAttr;
+// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
+// DEF-NEXT: while (true) {
+// DEF-NEXT:   if (namedAttrIt == namedAttrRange.end())
+// DEF-NEXT:     break;
+// DEF:        if (namedAttrIt->getName() == AOp::bAttrAttrName(*odsOpName))
+// DEF-NEXT:     tblgen_bAttr = namedAttrIt->getValue();
+// DEF:        if (namedAttrIt->getName() == AOp::cAttrAttrName(*odsOpName))
+// DEF-NEXT:     tblgen_cAttr = namedAttrIt->getValue();
 // DEF:      if (tblgen_aAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF:      if (tblgen_bAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF:      if (tblgen_cAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test.a_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
 
 // Test getter methods
 // ---
 
 // DEF:      some-attr-kind AOp::aAttrAttr()
-// DEF-NEXT:   (*this)->getAttr(aAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT:   ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 0, (*this)->getAttrs().end() - 0, aAttrAttrName()).cast<some-attr-kind>()
 // DEF:      some-return-type AOp::aAttr() {
 // DEF-NEXT:   auto attr = aAttrAttr()
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::bAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(bAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, bAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
 // DEF:      some-return-type AOp::bAttr() {
 // DEF-NEXT:   auto attr = bAttrAttr();
 // DEF-NEXT:   if (!attr)
@@ -95,7 +106,7 @@
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::cAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(cAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   ::mlir::impl::getAttrFromSortedRange((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, cAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
 // DEF:      ::llvm::Optional<some-return-type> AOp::cAttr() {
 // DEF-NEXT:   auto attr = cAttrAttr()
 // DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
@@ -179,29 +190,29 @@
 // ---
 
 // DEF:      ::mlir::LogicalResult AgetOpAdaptor::verify
-// DEF:      auto tblgen_aAttr = odsAttrs.get("aAttr");
-// DEF-NEXT: if (!tblgen_aAttr)
-// DEF-NEXT.   return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'");
+// DEF:      ::mlir::Attribute tblgen_aAttr;
+// DEF-NEXT: while (true)
+// DEF:      ::mlir::Attribute tblgen_bAttr;
+// DEF-NEXT: ::mlir::Attribute tblgen_cAttr;
+// DEF-NEXT: while (true)
 // DEF:      if (tblgen_aAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_bAttr = odsAttrs.get("bAttr");
-// DEF-NEXT: if (tblgen_bAttr && !((some-condition)))
+// DEF:      if (tblgen_bAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind");
-// DEF:      auto tblgen_cAttr = odsAttrs.get("cAttr");
-// DEF-NEXT: if (tblgen_cAttr && !((some-condition)))
+// DEF:      if (tblgen_cAttr && !((some-condition)))
 // DEF-NEXT:   return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind");
 
 // Test getter methods
 // ---
 
 // DEF:      some-attr-kind AgetOp::getAAttrAttr()
-// DEF-NEXT:   (*this)->getAttr(getAAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT:   ::mlir::impl::getAttrFromSortedRange({{.*}}).cast<some-attr-kind>()
 // DEF:      some-return-type AgetOp::getAAttr() {
 // DEF-NEXT:   auto attr = getAAttrAttr()
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AgetOp::getBAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(getBAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return ::mlir::impl::getAttrFromSortedRange({{.*}}).dyn_cast_or_null<some-attr-kind>()
 // DEF:      some-return-type AgetOp::getBAttr() {
 // DEF-NEXT:   auto attr = getBAttrAttr();
 // DEF-NEXT:   if (!attr)
@@ -209,7 +220,7 @@
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AgetOp::getCAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(getCAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return ::mlir::impl::getAttrFromSortedRange({{.*}}).dyn_cast_or_null<some-attr-kind>()
 // DEF:      ::llvm::Optional<some-return-type> AgetOp::getCAttr() {
 // DEF-NEXT:   auto attr = getCAttrAttr()
 // DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/ADT/StringSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Signals.h"
 #include "llvm/TableGen/Error.h"
@@ -43,10 +44,22 @@
 static const char *const odsBuilder = "odsBuilder";
 static const char *const builderOpState = "odsState";
 
-/// Code for an Op to lookup an attribute. Uses cached identifiers.
+/// The names of the implicit attributes that contain variadic operand and
+/// result segment sizes.
+static const char *const operandSegmentAttrName = "operand_segment_sizes";
+static const char *const resultSegmentAttrName = "result_segment_sizes";
+
+/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange
+/// lookup.
 ///
-/// {0}: The attribute's getter name.
-static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
+/// {0}: Code snippet to get the attribute's name or identifier.
+/// {1}: The lower bound on the sorted subrange.
+/// {2}: The upper bound on the sorted subrange.
+/// {3}: Code snippet to get the array of named attributes.
+/// {4}: "Named" to get the named attribute.
+static const char *const subrangeGetAttr =
+    "::mlir::impl::get{4}AttrFromSortedRange({3}.begin() + {1}, {3}.end() - "
+    "{2}, {0})";
 
 /// The logic to calculate the actual value range for a declared operand/result
 /// of an op with variadic operands/results. Note that this logic is not for
@@ -80,16 +93,6 @@
 /// of an op with variadic operands/results. Note that this logic is assumes
 /// the op has an attribute specifying the size of each operand/result segment
 /// (variadic or not).
-///
-/// {0}: The name of the attribute specifying the segment sizes.
-static const char *const adapterSegmentSizeAttrInitCode = R"(
-  assert(odsAttrs && "missing segment size attribute for op");
-  auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
-)";
-static const char *const opSegmentSizeAttrInitCode = R"(
-  auto sizeAttr =
-      (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>();
-)";
 static const char *const attrSizedSegmentValueRangeCalcCode = R"(
   const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
   if (sizeAttr.isSplat())
@@ -100,6 +103,19 @@
     start += sizeAttrValueIt[i];
   return {start, sizeAttrValueIt[index]};
 )";
+/// The code snippet to initialize the sizes for the value range calculation.
+///
+/// {0}: The code to get the attribute.
+static const char *const adapterSegmentSizeAttrInitCode = R"(
+  assert(odsAttrs && "missing segment size attribute for op");
+  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
+)";
+/// The code snippet to initialize the sizes for the value range calculation.
+///
+/// {0}: The code to get the attribute.
+static const char *const opSegmentSizeAttrInitCode = R"(
+  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
+)";
 
 /// The logic to calculate the actual value range for a declared operand
 /// of an op with variadic of variadic operands within the OpAdaptor.
@@ -179,11 +195,31 @@
 }
 
 namespace {
+/// Metadata on a registered attribute. Given that attributes are stored in
+/// sorted order on operations, we can use information from ODS to deduce the
+/// number of required attributes less than and greater than each attribute,
+/// allowing us to search only a subrange of the attributes in ODS-generated
+/// getters.
+struct AttributeMetadata {
+  /// The attribute name.
+  StringRef attrName;
+  /// Whether the attribute is required.
+  bool isRequired;
+  /// The ODS attribute constraint. Not present for implicit attributes.
+  Optional<Attribute> constraint;
+  /// The number of required attributes less than this attribute.
+  unsigned lowerBound;
+  /// The number of required attributes greater than this attribute.
+  unsigned upperBound;
+};
+
 /// Helper class to select between OpAdaptor and Op code templates.
 class OpOrAdaptorHelper {
 public:
   OpOrAdaptorHelper(const Operator &op, bool emitForOp)
-      : op(op), emitForOp(emitForOp) {}
+      : op(op), emitForOp(emitForOp) {
+    computeAttrMetadata();
+  }
 
   /// Object that wraps a functor in a stream operator for interop with
   /// llvm::formatv.
@@ -208,14 +244,31 @@
   };
 
   // Generate code for getting an attribute.
-  Formatter getAttr(StringRef attrName) const {
+  Formatter getAttr(StringRef attrName, bool isNamed = false) const {
+    assert(attrMetadata.count(attrName) && "expected attribute metadata");
+    return [this, attrName, isNamed](raw_ostream &os) -> raw_ostream & {
+      const AttributeMetadata &attr = attrMetadata.find(attrName)->second;
+      return os << formatv(subrangeGetAttr, getAttrName(attrName),
+                           attr.lowerBound, attr.upperBound, getAttrRange(),
+                           isNamed ? "Named" : "");
+    };
+  }
+
+  // Generate code for getting the name of an attribute.
+  Formatter getAttrName(StringRef attrName) const {
     return [this, attrName](raw_ostream &os) -> raw_ostream & {
-      if (!emitForOp)
-        return os << formatv("odsAttrs.get(\"{0}\")", attrName);
-      return os << formatv(opGetAttr, op.getGetterName(attrName));
+      if (emitForOp)
+        return os << op.getGetterName(attrName) << "AttrName()";
+      return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(),
+                           op.getGetterName(attrName));
     };
   }
 
+  // Get the code snippet for getting the named attribute range.
+  StringRef getAttrRange() const {
+    return emitForOp ? "(*this)->getAttrs()" : "odsAttrs";
+  }
+
   // Get the prefix code for emitting an error.
   Formatter emitErrorPrefix() const {
     return [this](raw_ostream &os) -> raw_ostream & {
@@ -254,14 +307,74 @@
   // Return the ODS operation wrapper.
   const Operator &getOp() const { return op; }
 
+  // Get the attribute metadata sorted by name.
+  const llvm::MapVector<StringRef, AttributeMetadata> &getAttrMetadata() const {
+    return attrMetadata;
+  }
+
 private:
+  // Compute the attribute metadata.
+  void computeAttrMetadata();
+
   // The operation ODS wrapper.
   const Operator &op;
   // True if code is being generate for an op. False for an adaptor.
   const bool emitForOp;
+
+  // The attribute metadata, mapped by name.
+  llvm::MapVector<StringRef, AttributeMetadata> attrMetadata;
+  // The number of required attributes.
+  unsigned numRequired;
 };
+
 } // namespace
 
+void OpOrAdaptorHelper::computeAttrMetadata() {
+  // Enumerate the attribute names of this op, ensuring the attribute names are
+  // unique in case implicit attributes are explicitly registered.
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    Attribute attr = namedAttr.attr;
+    bool isOptional =
+        attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr();
+    attrMetadata.insert(
+        {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}});
+  }
+  // Include key attributes from several traits as implicitly registered.
+  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+    attrMetadata.insert(
+        {operandSegmentAttrName,
+         AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true,
+                           /*constraint=*/llvm::None}});
+  }
+  if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+    attrMetadata.insert(
+        {resultSegmentAttrName,
+         AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true,
+                           /*constraint=*/llvm::None}});
+  }
+
+  // Store the metadata in sorted order.
+  SmallVector<AttributeMetadata> sortedAttrMetadata =
+      llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector()));
+  llvm::sort(sortedAttrMetadata,
+             [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) {
+               return lhs.attrName < rhs.attrName;
+             });
+
+  // Compute the subrange bounds for each attribute.
+  numRequired = 0;
+  for (AttributeMetadata &attr : sortedAttrMetadata) {
+    attr.lowerBound = numRequired;
+    numRequired += attr.isRequired;
+  }
+  for (AttributeMetadata &attr : sortedAttrMetadata)
+    attr.upperBound = numRequired - attr.lowerBound - attr.isRequired;
+
+  // Store the results back into the map.
+  for (const AttributeMetadata &attr : sortedAttrMetadata)
+    attrMetadata.insert({attr.attrName, attr});
+}
+
 //===----------------------------------------------------------------------===//
 // Op emitter
 //===----------------------------------------------------------------------===//
@@ -438,7 +551,7 @@
   const Record &def;
 
   // The wrapper operator class for querying information from this op.
-  Operator op;
+  const Operator &op;
 
   // The C++ code builder for this op
   OpClass opClass;
@@ -448,6 +561,9 @@
 
   // The emitter containing all of the locally emitted verification functions.
   const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+  // Helper for emitting op code.
+  OpOrAdaptorHelper emitHelper;
 };
 
 } // namespace
@@ -476,20 +592,59 @@
   }
 }
 
+/// Generate verification on native traits requiring attributes.
+static void genNativeTraitAttrVerifier(MethodBody &body,
+                                       const OpOrAdaptorHelper &emitHelper) {
+  // Check that the variadic segment sizes attribute exists and contains the
+  // expected number of elements.
+  //
+  // {0}: Attribute name.
+  // {1}: Expected number of elements.
+  // {2}: "operand" or "result".
+  // {3}: Emit error prefix.
+  const char *const checkAttrSizedValueSegmentsCode = R"(
+  {
+    auto sizeAttr = tblgen_{0}.cast<::mlir::DenseIntElementsAttr>();
+    auto numElements =
+        sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements();
+    if (numElements != {1})
+      return {3}"'{0}' attribute for specifying {2} segments must have {1} "
+                "elements, but got ") << numElements;
+  }
+  )";
+
+  // Verify a few traits first so that we can use getODSOperands() and
+  // getODSResults() in the rest of the verifier.
+  auto &op = emitHelper.getOp();
+  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+    body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName,
+                    op.getNumOperands(), "operand",
+                    emitHelper.emitErrorPrefix());
+  }
+  if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) {
+    body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName,
+                    op.getNumResults(), "result", emitHelper.emitErrorPrefix());
+  }
+}
+
 // Generate attribute verification. If an op instance is not available, then
 // attribute checks that require one will not be emitted.
+//
+// Attribute verification is performed as follows:
+//
+// 1. Verify that all required attributes are present in sorted order. This
+// ensures that we can use subrange lookup even with potentially missing
+// attributes.
+// 2. Verify native trait attributes so that other attributes may call methods
+// that depend on the validity of these attributes, e.g. segment size attributes
+// and operand or result getters.
+// 3. Verify the constraints on all present attributes.
 static void genAttributeVerifier(
     const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body,
     const StaticVerifierFunctionEmitter &staticVerifierEmitter) {
-  // Check that a required attribute exists.
-  //
-  // {0}: Attribute variable name.
-  // {1}: Emit error prefix.
-  // {2}: Attribute name.
-  const char *const verifyRequiredAttr = R"(
-  if (!{0})
-    return {1}"requires attribute '{2}'");
-)";
+  if (emitHelper.getAttrMetadata().empty())
+    return;
+
   // Verify the attribute if it is present. This assumes that default values
   // are valid. This code snippet pastes the condition inline.
   //
@@ -501,8 +656,8 @@
   // {3}: Attribute name.
   // {4}: Attribute/constraint description.
   const char *const verifyAttrInline = R"(
-    if ({0} && !({1}))
-      return {2}"attribute '{3}' failed to satisfy constraint: {4}");
+  if ({0} && !({1}))
+    return {2}"attribute '{3}' failed to satisfy constraint: {4}");
 )";
   // Verify the attribute using a uniqued constraint. Can only be used within
   // the context of an op.
@@ -511,54 +666,128 @@
   // {1}: Attribute variable name.
   // {2}: Attribute name.
   const char *const verifyAttrUnique = R"(
-    if (::mlir::failed({0}(*this, {1}, "{2}")))
-      return ::mlir::failure();
+  if (::mlir::failed({0}(*this, {1}, "{2}")))
+    return ::mlir::failure();
 )";
 
-  for (const auto &namedAttr : emitHelper.getOp().getAttributes()) {
-    const auto &attr = namedAttr.attr;
-    StringRef attrName = namedAttr.name;
+  // Traverse the array until the required attribute is found. Return an error
+  // if the traversal reached the end.
+  //
+  // {0}: Code to get the name of the attribute.
+  // {1}: The emit error prefix.
+  // {2}: The name of the attribute.
+  const char *const findRequiredAttr = R"(while (true) {{
+  if (namedAttrIt == namedAttrRange.end())
+    return {1}"requires attribute '{2}'");
+  if (namedAttrIt->getName() == {0}) {{
+    tblgen_{2} = namedAttrIt->getValue();
+    break;
+  })";
+
+  // Emit a check to see if the iteration has encountered an optional attribute.
+  //
+  // {0}: Code to get the name of the attribute.
+  // {1}: The name of the attribute.
+  const char *const checkOptionalAttr = R"(
+  else if (namedAttrIt->getName() == {0}) {{
+    tblgen_{1} = namedAttrIt->getValue();
+  })";
+
+  // Emit the start of the loop for checking trailing attributes.
+  const char *const checkTrailingAttrs = R"(while (true) {
+  if (namedAttrIt == namedAttrRange.end()) {
+    break;
+  })";
+
+  // Return true if a verifier can be emitted for the attribute: it is not a
+  // derived attribute, it has a predicate, its condition is not empty, and, for
+  // adaptors, the condition does not reference the op.
+  const auto canEmitVerifier = [&](Attribute attr) {
     if (attr.isDerivedAttr())
-      continue;
+      return false;
+    Pred pred = attr.getPredicate();
+    if (pred.isNull())
+      return false;
+    std::string condition = pred.getCondition();
+    return !condition.empty() && (!StringRef(condition).contains("$_op") ||
+                                  emitHelper.isEmittingForOp());
+  };
 
-    bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional();
-    auto attrPred = attr.getPredicate();
-    std::string condition = attrPred.isNull() ? "" : attrPred.getCondition();
-    // If the attribute's condition needs an op but none is available, then the
-    // condition cannot be emitted.
-    bool canEmitCondition =
-        !condition.empty() && (!StringRef(condition).contains("$_op") ||
-                               emitHelper.isEmittingForOp());
-
-    // Prefix with `tblgen_` to avoid hiding the attribute accessor.
-    std::string varName = (tblgenNamePrefix + attrName).str();
-
-    // If the attribute is not required and we cannot emit the condition, then
-    // there is nothing to be done.
-    if (allowMissingAttr && !canEmitCondition)
-      continue;
+  // Emit the verifier for the attribute.
+  const auto emitVerifier = [&](Attribute attr, StringRef attrName,
+                                StringRef varName) {
+    std::string condition = attr.getPredicate().getCondition();
 
-    body << formatv("  {\n    auto {0} = {1};", varName,
-                    emitHelper.getAttr(attrName));
+    Optional<StringRef> constraintFn;
+    if (emitHelper.isEmittingForOp() &&
+        (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
+      body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
+    } else {
+      body << formatv(verifyAttrInline, varName,
+                      tgfmt(condition, &ctx.withSelf(varName)),
+                      emitHelper.emitErrorPrefix(), attrName,
+                      escapeString(attr.getSummary()));
+    }
+  };
 
-    if (!allowMissingAttr) {
-      body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(),
-                      attrName);
+  // Prefix variables with `tblgen_` to avoid hiding the attribute accessor.
+  const auto getVarName = [&](StringRef attrName) {
+    return (tblgenNamePrefix + attrName).str();
+  };
+
+  body.indent() << formatv("auto namedAttrRange = {0};\n",
+                           emitHelper.getAttrRange());
+  body << "auto namedAttrIt = namedAttrRange.begin();\n";
+
+  // Iterate over the attributes in sorted order. Keep track of the optional
+  // attributes that may be encountered along the way.
+  SmallVector<const AttributeMetadata *> optionalAttrs;
+  for (const std::pair<StringRef, AttributeMetadata> &it :
+       emitHelper.getAttrMetadata()) {
+    const AttributeMetadata &metadata = it.second;
+    if (!metadata.isRequired) {
+      optionalAttrs.push_back(&metadata);
+      continue;
     }
-    if (canEmitCondition) {
-      Optional<StringRef> constraintFn;
-      if (emitHelper.isEmittingForOp() &&
-          (constraintFn = staticVerifierEmitter.getAttrConstraintFn(attr))) {
-        body << formatv(verifyAttrUnique, *constraintFn, varName, attrName);
-      } else {
-        body << formatv(verifyAttrInline, varName,
-                        tgfmt(condition, &ctx.withSelf(varName)),
-                        emitHelper.emitErrorPrefix(), attrName,
-                        escapeString(attr.getSummary()));
-      }
+
+    body << formatv("::mlir::Attribute {0};\n", getVarName(it.first));
+    for (const AttributeMetadata *optional : optionalAttrs) {
+      body << formatv("::mlir::Attribute {0};\n",
+                      getVarName(optional->attrName));
+    }
+    body << formatv(findRequiredAttr, emitHelper.getAttrName(it.first),
+                    emitHelper.emitErrorPrefix(), it.first);
+    for (const AttributeMetadata *optional : optionalAttrs) {
+      body << formatv(checkOptionalAttr,
+                      emitHelper.getAttrName(optional->attrName),
+                      optional->attrName);
+    }
+    body << "\n  ++namedAttrIt;\n}\n";
+    optionalAttrs.clear();
+  }
+  // Get trailing optional attributes.
+  if (!optionalAttrs.empty()) {
+    for (const AttributeMetadata *optional : optionalAttrs) {
+      body << formatv("::mlir::Attribute {0};\n",
+                      getVarName(optional->attrName));
+    }
+    body << checkTrailingAttrs;
+    for (const AttributeMetadata *optional : optionalAttrs) {
+      body << formatv(checkOptionalAttr,
+                      emitHelper.getAttrName(optional->attrName),
+                      optional->attrName);
     }
-    body << "  }\n";
+    body << "\n  ++namedAttrIt;\n}\n";
   }
+  body.unindent();
+
+  // Emit the checks for segment attributes first so that the other constraints
+  // can call operand and result getters.
+  genNativeTraitAttrVerifier(body, emitHelper);
+
+  for (const auto &namedAttr : emitHelper.getOp().getAttributes())
+    if (canEmitVerifier(namedAttr.attr))
+      emitVerifier(namedAttr.attr, namedAttr.name, getVarName(namedAttr.name));
 }
 
 /// Op extra class definitions have a `$cppClass` substitution that is to be
@@ -573,7 +802,8 @@
     : def(op.getDef()), op(op),
       opClass(op.getCppClassName(), op.getExtraClassDeclaration(),
               formatExtraDefinitions(op)),
-      staticVerifierEmitter(staticVerifierEmitter) {
+      staticVerifierEmitter(staticVerifierEmitter),
+      emitHelper(op, /*emitForOp=*/true) {
   verifyCtx.withOp("(*this->getOperation())");
   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
 
@@ -635,26 +865,12 @@
                                    op.getOperationName() + " (from line " +
                                    Twine(line) + ")");
 }
+
 #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O)
 
 void OpEmitter::genAttrNameGetters() {
-  // A map of attribute names (including implicit attributes) registered to the
-  // current operation, to the relative order in which they were registered.
-  llvm::MapVector<StringRef, unsigned> attributeNames;
-
-  // Enumerate the attribute names of this op, assigning each a relative
-  // ordering.
-  auto addAttrName = [&](StringRef name) {
-    unsigned index = attributeNames.size();
-    attributeNames.insert({name, index});
-  };
-  for (const NamedAttribute &namedAttr : op.getAttributes())
-    addAttrName(namedAttr.name);
-  // Include key attributes from several traits as implicitly registered.
-  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
-    addAttrName("operand_segment_sizes");
-  if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
-    addAttrName("result_segment_sizes");
+  const llvm::MapVector<StringRef, AttributeMetadata> &attributes =
+      emitHelper.getAttrMetadata();
 
   // Emit the getAttributeNames method.
   {
@@ -662,20 +878,18 @@
         "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames");
     ERROR_IF_PRUNED(method, "getAttributeNames", op);
     auto &body = method->body();
-    if (attributeNames.empty()) {
+    if (attributes.empty()) {
       body << "  return {};";
-    } else {
-      body << "  static ::llvm::StringRef attrNames[] = {";
-      llvm::interleaveComma(llvm::make_first_range(attributeNames), body,
-                            [&](StringRef attrName) {
-                              body << "::llvm::StringRef(\"" << attrName
-                                   << "\")";
-                            });
-      body << "};\n  return ::llvm::makeArrayRef(attrNames);";
+      // Nothing else to do if there are no registered attributes. Exit early.
+      return;
     }
+    body << "  static ::llvm::StringRef attrNames[] = {";
+    llvm::interleaveComma(llvm::make_first_range(attributes), body,
+                          [&](StringRef attrName) {
+                            body << "::llvm::StringRef(\"" << attrName << "\")";
+                          });
+    body << "};\n  return ::llvm::makeArrayRef(attrNames);";
   }
-  if (attributeNames.empty())
-    return;
 
   // Emit the getAttributeNameForIndex methods.
   {
@@ -697,14 +911,14 @@
   assert(index < {0} && "invalid attribute index");
   return name.getRegisteredInfo()->getAttributeNames()[index];
 )";
-    method->body() << formatv(getAttrName, attributeNames.size());
+    method->body() << formatv(getAttrName, attributes.size());
   }
 
   // Generate the AttrName methods, that expose the attribute names to
   // users.
   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
-  for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
-    for (StringRef name : op.getGetterNames(attrIt.first)) {
+  for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) {
+    for (StringRef name : op.getGetterNames(attrIt.value())) {
       std::string methodName = (name + "AttrName").str();
 
       // Generate the non-static variant.
@@ -712,7 +926,7 @@ auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); ERROR_IF_PRUNED(method, methodName, op); - method->body() << llvm::formatv(attrNameMethodBody, attrIt.second); + method->body() << llvm::formatv(attrNameMethodBody, attrIt.index()); } // Generate the static variant. @@ -722,7 +936,7 @@ MethodParameter("::mlir::OperationName", "name")); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, - "name, " + Twine(attrIt.second)); + "name, " + Twine(attrIt.index())); } } } @@ -772,12 +986,13 @@ // Generate named accessor with Attribute return type. This is a wrapper class // that allows referring to the attributes via accessors instead of having to // use the string interface for better compile time verification. - auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { + auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName, + Attribute attr) { auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr"); if (!method) return; method->body() << formatv( - " return {0}.{1}<{2}>();", formatv(opGetAttr, name), + " return {0}.{1}<{2}>();", emitHelper.getAttr(attrName), attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null" : "cast", attr.getStorageType()); @@ -788,7 +1003,7 @@ if (namedAttr.attr.isDerivedAttr()) { emitDerivedAttr(name, namedAttr.attr); } else { - emitAttrWithStorageType(name, namedAttr.attr); + emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); } } @@ -1041,8 +1256,8 @@ // array. 
std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string attr = op.getGetterName("operand_segment_sizes"); - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( @@ -1073,17 +1288,17 @@ << " auto mutableRange = " "::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; - if (attrSizedOperands) - body << ", ::mlir::MutableOperandRange::OperandSegment(" << i - << "u, *getOperation()->getAttrDictionary().getNamed(" - << op.getGetterName("operand_segment_sizes") << "AttrName()))"; + if (attrSizedOperands) { + body << formatv( + ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, + emitHelper.getAttr(operandSegmentAttrName, /*isNamed=*/true)); + } body << ");\n"; // If this operand is a nested variadic, we split the range into a // MutableOperandRangeRange that provides a range over all of the // sub-ranges. if (operand.isVariadicOfVariadic()) { - // body << " return " "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" << op.getGetterName( @@ -1129,8 +1344,8 @@ // Build the initializer string for the result segment size attribute. std::string attrSizeInitCode; if (attrSizedResults) { - std::string attr = op.getGetterName("result_segment_sizes"); - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( @@ -1220,7 +1435,7 @@ } } -static bool canGenerateUnwrappedBuilder(Operator &op) { +static bool canGenerateUnwrappedBuilder(const Operator &op) { // If this op does not have native attributes at all, return directly to avoid // redefining builders. if (op.getNumNativeAttributes() == 0) @@ -1232,7 +1447,7 @@ // different from the wrapped mlir::Attribute type to avoid redefining // builders. 
This checks for the op has at least one such native attribute. for (int i = 0, e = op.getNumNativeAttributes(); i < e; ++i) { - NamedAttribute &namedAttr = op.getAttribute(i); + const NamedAttribute &namedAttr = op.getAttribute(i); if (canUseUnwrappedRawValue(namedAttr.attr)) { canGenerate = true; break; @@ -1241,7 +1456,7 @@ return canGenerate; } -static bool canInferType(Operator &op) { +static bool canInferType(const Operator &op) { return op.getTrait("::mlir::InferTypeOpInterface::Trait"); } @@ -1727,7 +1942,7 @@ const NamedAttribute &namedAttr = *arg.get(); const Attribute &attr = namedAttr.attr; - // inferred attributes don't need to be added to the param list. + // Inferred attributes don't need to be added to the param list. if (inferredAttributes.contains(namedAttr.name)) continue; @@ -1774,7 +1989,7 @@ // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { std::string argName = getArgumentName(op, i); - NamedTypeConstraint &operand = op.getOperand(i); + const NamedTypeConstraint &operand = op.getOperand(i); if (operand.constraint.isVariadicOfVariadic()) { body << " for (::mlir::ValueRange range : " << argName << ")\n " << builderOpState << ".addOperands(range);\n"; @@ -1800,7 +2015,7 @@ // If the operation has the operand segment size attribute, add it here. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string sizes = op.getGetterName("operand_segment_sizes"); + std::string sizes = op.getGetterName(operandSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" << builderOpState << ".name), " << "odsBuilder.getI32VectorAttr({"; @@ -2164,64 +2379,13 @@ ERROR_IF_PRUNED(method, "print", op); } -/// Generate verification on native traits requiring attributes. -static void genNativeTraitAttrVerifier(MethodBody &body, - const OpOrAdaptorHelper &emitHelper) { - // Check that the variadic segment sizes attribute exists and contains the - // expected number of elements. 
- // - // {0}: Attribute name. - // {1}: Expected number of elements. - // {2}: "operand" or "result". - // {3}: Attribute getter call. - // {4}: Emit error prefix. - const char *const checkAttrSizedValueSegmentsCode = R"( - { - auto sizeAttr = {3}.dyn_cast<::mlir::DenseIntElementsAttr>(); - if (!sizeAttr) - return {4}"missing segment sizes attribute '{0}'"); - auto numElements = - sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); - if (numElements != {1}) - return {4}"'{0}' attribute for specifying {2} segments must have {1} " - "elements, but got ") << numElements; - } - )"; - - // Verify a few traits first so that we can use getODSOperands() and - // getODSResults() in the rest of the verifier. - auto &op = emitHelper.getOp(); - for (auto &trait : op.getTraits()) { - auto *t = dyn_cast(&trait); - if (!t) - continue; - std::string traitName = t->getFullyQualifiedTraitName(); - if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") { - StringRef attrName = "operand_segment_sizes"; - body << formatv(checkAttrSizedValueSegmentsCode, attrName, - op.getNumOperands(), "operand", - emitHelper.getAttr(attrName), - emitHelper.emitErrorPrefix()); - } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") { - StringRef attrName = "result_segment_sizes"; - body << formatv( - checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(), - "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix()); - } - } -} - void OpEmitter::genVerifier() { auto *implMethod = opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl"); ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op); auto &implBody = implMethod->body(); - OpOrAdaptorHelper emitHelper(op, /*isOp=*/true); - genNativeTraitAttrVerifier(implBody, emitHelper); - populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter); genOperandResultVerifier(implBody, op.getOperands(), "operand"); 
genOperandResultVerifier(implBody, op.getResults(), "result"); @@ -2574,34 +2738,49 @@ // getters identical to those defined in the Op. class OpOperandAdaptorEmitter { public: - static void emitDecl(const Operator &op, - StaticVerifierFunctionEmitter &staticVerifierEmitter, - raw_ostream &os); - static void emitDef(const Operator &op, - StaticVerifierFunctionEmitter &staticVerifierEmitter, - raw_ostream &os); + static void + emitDecl(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); + static void + emitDef(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); private: explicit OpOperandAdaptorEmitter( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter); + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); // Add verification function. This generates a verify method for the adaptor // which verifies all the op-independent attribute constraints. void addVerification(); + // The operation for which to emit an adaptor. const Operator &op; + + // The generated adaptor class. Class adaptor; - StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // The emitter containing all of the locally emitted verification functions. + const StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // Helper for emitting adaptor code. 
+  OpOrAdaptorHelper emitHelper;
 };
 } // namespace
 
 OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(
-    const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter)
+    const Operator &op,
+    const StaticVerifierFunctionEmitter &staticVerifierEmitter)
     : op(op), adaptor(op.getAdaptorName()),
-      staticVerifierEmitter(staticVerifierEmitter) {
+      staticVerifierEmitter(staticVerifierEmitter),
+      emitHelper(op, /*emitForOp=*/false) {
   adaptor.addField("::mlir::ValueRange", "odsOperands");
   adaptor.addField("::mlir::DictionaryAttr", "odsAttrs");
   adaptor.addField("::mlir::RegionRange", "odsRegions");
+  adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName");
+
   const auto *attrSizedOperands =
       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
   {
@@ -2615,14 +2794,21 @@
     constructor->addMemberInitializer("odsOperands", "values");
     constructor->addMemberInitializer("odsAttrs", "attrs");
     constructor->addMemberInitializer("odsRegions", "regions");
+
+    MethodBody &body = constructor->body();
+    body.indent() << "if (odsAttrs)\n";
+    body.indent() << formatv(
+        "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n",
+        op.getOperationName());
   }
 
   {
-    auto *constructor = adaptor.addConstructor(
-        MethodParameter(op.getCppClassName() + " &", "op"));
+    auto *constructor =
+        adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op"));
     constructor->addMemberInitializer("odsOperands", "op->getOperands()");
     constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()");
     constructor->addMemberInitializer("odsRegions", "op->getRegions()");
+    constructor->addMemberInitializer("odsOpName", "op->getName()");
   }
 
   {
@@ -2630,8 +2816,11 @@
     ERROR_IF_PRUNED(m, "getOperands", op);
     m->body() << "  return odsOperands;";
   }
-  std::string attr = "operand_segment_sizes";
-  std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr);
+  std::string sizeAttrInit;
+  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
+    sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); + } generateNamedOperandGetters(op, adaptor, /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", @@ -2647,15 +2836,13 @@ Attribute attr) { auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr"); ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); - auto &body = method->body(); - body << " assert(odsAttrs && \"no attributes when constructing adapter\");" - << "\n " << attr.getStorageType() << " attr = " - << "odsAttrs.get(\"" << name << "\")."; - if (attr.hasDefaultValue() || attr.isOptional()) - body << "dyn_cast_or_null<"; - else - body << "cast<"; - body << attr.getStorageType() << ">();\n"; + auto &body = method->body().indent(); + body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n" + << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name), + attr.hasDefaultValue() || attr.isOptional() + ? "dyn_cast_or_null" + : "cast", + attr.getStorageType()); if (attr.hasDefaultValue()) { // Use the default value if attribute is not set. 
@@ -2721,24 +2908,23 @@ ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); - OpOrAdaptorHelper emitHelper(op, /*isOp=*/false); - genNativeTraitAttrVerifier(body, emitHelper); FmtContext verifyCtx; populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); body << " return ::mlir::success();"; } void OpOperandAdaptorEmitter::emitDecl( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); }