diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -428,6 +428,23 @@ return findAttrUnsorted(first, last, name); } +/// Get an attribute from a sorted range of named attributes. Returns null if +/// the attribute was not found. +template +Attribute getAttrSorted(IteratorT first, IteratorT last, NameT name) { + std::pair result = findAttrSorted(first, last, name); + return result.second ? result.first->getValue() : Attribute(); +} + +/// Get a named attribute from a sorted range of named attributes. Returns +/// None if the attribute was not found. +template +Optional getNamedAttrSorted(IteratorT first, IteratorT last, + NameT name) { + std::pair result = findAttrSorted(first, last, name); + return result.second ? *result.first : Optional(); +} + } // namespace impl //===----------------------------------------------------------------------===// @@ -447,7 +464,7 @@ NamedAttrList() : dictionarySorted({}, true) {} NamedAttrList(ArrayRef attributes); NamedAttrList(DictionaryAttr attributes); - NamedAttrList(const_iterator in_start, const_iterator in_end); + NamedAttrList(const_iterator inStart, const_iterator inEnd); bool operator!=(const NamedAttrList &other) const { return !(*this == other); @@ -478,15 +495,15 @@ typename = std::enable_if_t::iterator_category, std::input_iterator_tag>::value>> - void append(IteratorT in_start, IteratorT in_end) { + void append(IteratorT inStart, IteratorT inEnd) { // TODO: expand to handle case where values appended are in order & after // end of current list. dictionarySorted.setPointerAndInt(nullptr, false); - attrs.append(in_start, in_end); + attrs.append(inStart, inEnd); } /// Replaces the attributes with new list of attributes.
- void assign(const_iterator in_start, const_iterator in_end); + void assign(const_iterator inStart, const_iterator inEnd); /// Replaces the attributes with new list of attributes. void assign(ArrayRef range) { diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -81,7 +81,7 @@ // ----- // expected-error @+1 {{op requires attribute 'sym_name'}} -"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> () +"llvm.mlir.global"() ({}) {linkage = "private", type = i64, constant, global_type = i64, value = 42 : i64} : () -> () // ----- diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -230,7 +230,8 @@ "test.int_attrs"() { any_i32_attr = 5.0 : f32, si32_attr = 7 : si32, - ui32_attr = 6 : ui32 + ui32_attr = 6 : ui32, + index_attr = 1 : index } : () -> () return } diff --git a/mlir/test/mlir-tblgen/constraint-unique.td b/mlir/test/mlir-tblgen/constraint-unique.td --- a/mlir/test/mlir-tblgen/constraint-unique.td +++ b/mlir/test/mlir-tblgen/constraint-unique.td @@ -116,7 +116,7 @@ /// Test that the uniqued constraints are being used. // CHECK-LABEL: OpA::verify -// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: auto [[$B_ATTR:.*b]] = requiredAttr_b; // CHECK: if (::mlir::failed([[$A_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) // CHECK-NEXT: return ::mlir::failure(); // CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); @@ -137,7 +137,7 @@ /// Test that the op with the same predicates but different with descriptions /// uses the different constraints.
// CHECK-LABEL: OpC::verify -// CHECK: auto [[$B_ATTR:.*b]] = (*this)->getAttr(bAttrName()); +// CHECK: auto [[$B_ATTR:.*b]] = requiredAttr_b; // CHECK: if (::mlir::failed([[$O_ATTR_CONSTRAINT]](*this, [[$B_ATTR]], "b"))) // CHECK-NEXT: return ::mlir::failure(); // CHECK: auto [[$A_VALUE_GROUP:.*]] = getODSOperands(0); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -65,15 +65,20 @@ // --- // DEF: ::mlir::LogicalResult AOpAdaptor::verify -// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); -// DEF-NEXT: if (!tblgen_aAttr) -// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); +// DEF: ::mlir::Attribute requiredAttr_aAttr; +// DEF-NEXT: while (true) { +// DEF-NEXT: if (namedAttrIt == namedAttrRange.end()) +// DEF-NEXT: return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); +// DEF-NEXT: if (namedAttrIt->getName() == AOp::aAttrAttrName(*odsOpName)) { +// DEF-NEXT: requiredAttr_aAttr = namedAttrIt->getValue(); +// DEF-NEXT: break; +// DEF: auto tblgen_aAttr = requiredAttr_aAttr; // DEF: if (tblgen_aAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr"); +// DEF: auto tblgen_bAttr = ::mlir::impl::getAttrSorted(odsAttrs.begin() + 1, odsAttrs.end() - 0, AOp::bAttrAttrName(*odsOpName)); // DEF-NEXT: if (tblgen_bAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test.a_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr"); +// DEF: auto tblgen_cAttr = ::mlir::impl::getAttrSorted(odsAttrs.begin() + 1, odsAttrs.end() - 0, AOp::cAttrAttrName(*odsOpName)); // DEF-NEXT: if (tblgen_cAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test.a_op' op
""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); @@ -81,13 +86,13 @@ // --- // DEF: some-attr-kind AOp::aAttrAttr() -// DEF-NEXT: (*this)->getAttr(aAttrAttrName()).cast() +// DEF-NEXT: ::mlir::impl::getAttrSorted((*this)->getAttrs().begin() + 0, (*this)->getAttrs().end() - 0, aAttrAttrName()).cast() // DEF: some-return-type AOp::aAttr() { // DEF-NEXT: auto attr = aAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::bAttrAttr() -// DEF-NEXT: return (*this)->getAttr(bAttrAttrName()).dyn_cast_or_null() +// DEF-NEXT: ::mlir::impl::getAttrSorted((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, bAttrAttrName()).dyn_cast_or_null() // DEF: some-return-type AOp::bAttr() { // DEF-NEXT: auto attr = bAttrAttr(); // DEF-NEXT: if (!attr) @@ -95,7 +100,7 @@ // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::cAttrAttr() -// DEF-NEXT: return (*this)->getAttr(cAttrAttrName()).dyn_cast_or_null() +// DEF-NEXT: ::mlir::impl::getAttrSorted((*this)->getAttrs().begin() + 1, (*this)->getAttrs().end() - 0, cAttrAttrName()).dyn_cast_or_null() // DEF: ::llvm::Optional AOp::cAttr() { // DEF-NEXT: auto attr = cAttrAttr() // DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); @@ -179,15 +184,13 @@ // --- // DEF: ::mlir::LogicalResult AgetOpAdaptor::verify -// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); -// DEF-NEXT: if (!tblgen_aAttr) -// DEF-NEXT.
return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'"); +// DEF: auto tblgen_aAttr = requiredAttr_aAttr; // DEF: if (tblgen_aAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr"); +// DEF: auto tblgen_bAttr = ::mlir::impl::getAttrSorted({{.*}}); // DEF-NEXT: if (tblgen_bAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); -// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr"); +// DEF: auto tblgen_cAttr = ::mlir::impl::getAttrSorted({{.*}}); // DEF-NEXT: if (tblgen_cAttr && !((some-condition))) // DEF-NEXT: return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); @@ -195,13 +198,13 @@ // --- // DEF: some-attr-kind AgetOp::getAAttrAttr() -// DEF-NEXT: (*this)->getAttr(getAAttrAttrName()).cast() +// DEF-NEXT: ::mlir::impl::getAttrSorted({{.*}}).cast() // DEF: some-return-type AgetOp::getAAttr() { // DEF-NEXT: auto attr = getAAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AgetOp::getBAttrAttr() -// DEF-NEXT: return (*this)->getAttr(getBAttrAttrName()).dyn_cast_or_null() +// DEF-NEXT: return ::mlir::impl::getAttrSorted({{.*}}).dyn_cast_or_null() // DEF: some-return-type AgetOp::getBAttr() { // DEF-NEXT: auto attr = getBAttrAttr(); // DEF-NEXT: if (!attr) @@ -209,7 +212,7 @@ // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AgetOp::getCAttrAttr() -// DEF-NEXT: return (*this)->getAttr(getCAttrAttrName()).dyn_cast_or_null() +// DEF-NEXT: return ::mlir::impl::getAttrSorted({{.*}}).dyn_cast_or_null() // DEF: ::llvm::Optional AgetOp::getCAttr() { // DEF-NEXT: auto attr = getCAttrAttr() // DEF-NEXT: return attr ?
::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -43,10 +43,21 @@ static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; -/// Code for an Op to lookup an attribute. Uses cached identifiers. +/// The names of the implicit attributes that contain variadic operand and +/// result segment sizes. +static const char *const operandSegmentAttrName = "operand_segment_sizes"; +static const char *const resultSegmentAttrName = "result_segment_sizes"; + +/// Code for an Op to lookup an attribute. Uses cached identifiers and subrange +/// lookup. /// -/// {0}: The attribute's getter name. -static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())"; +/// {0}: Code snippet to get the attribute's name or identifier. +/// {1}: The lower bound on the sorted subrange, as an offset from the start. +/// {2}: The upper bound on the sorted subrange, as an offset from the end. +/// {3}: Code snippet to get the array of named attributes. +/// {4}: "Named" to get the named attribute, or empty to get its value. +static const char *const subrangeGetAttr = + "::mlir::impl::get{4}AttrSorted({3}.begin() + {1}, {3}.end() - {2}, {0})"; /// The logic to calculate the actual value range for a declared operand/result /// of an op with variadic operands/results. Note that this logic is not for @@ -80,16 +91,6 @@ /// of an op with variadic operands/results. Note that this logic is assumes /// the op has an attribute specifying the size of each operand/result segment /// (variadic or not). -/// -/// {0}: The name of the attribute specifying the segment sizes.
-static const char *const adapterSegmentSizeAttrInitCode = R"( - assert(odsAttrs && "missing segment size attribute for op"); - auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); -)"; -static const char *const opSegmentSizeAttrInitCode = R"( - auto sizeAttr = - (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>(); -)"; static const char *const attrSizedSegmentValueRangeCalcCode = R"( const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin(); if (sizeAttr.isSplat()) @@ -100,6 +101,19 @@ start += sizeAttrValueIt[i]; return {start, sizeAttrValueIt[index]}; )"; +/// The code snippet to initialize the sizes for the value range calculation. +/// +/// {0}: The code to get the attribute. +static const char *const adapterSegmentSizeAttrInitCode = R"( + assert(odsAttrs && "missing segment size attribute for op"); + auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>(); +)"; +/// The code snippet to initialize the sizes for the value range calculation. +/// +/// {0}: The code to get the attribute. +static const char *const opSegmentSizeAttrInitCode = R"( + auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>(); +)"; /// The logic to calculate the actual value range for a declared operand /// of an op with variadic of variadic operands within the OpAdaptor. @@ -179,11 +193,30 @@ } namespace { +// Metadata on a registered attribute. Because attributes are stored in sorted +// order on operations, we can use information from ODS to deduce the number +// of required attributes less than and greater than each attribute, allowing +// us to search only a subrange of the attributes in ODS-generated getters. +struct AttributeMetadata { + // The attribute name. + StringRef attrName; + // Whether the attribute is required. + bool isRequired; + // The ODS attribute constraint. Not present for implicit attributes. + Optional constraint; + // The number of required attributes less than this attribute.
+ unsigned lowerBound; + // The number of required attributes greater than this attribute. + unsigned upperBound; +}; + /// Helper class to select between OpAdaptor and Op code templates. class OpOrAdaptorHelper { public: OpOrAdaptorHelper(const Operator &op, bool emitForOp) - : op(op), emitForOp(emitForOp) {} + : op(op), emitForOp(emitForOp) { + computeAttrMetadata(); + } /// Object that wraps a functor in a stream operator for interop with /// llvm::formatv. @@ -208,14 +241,31 @@ }; // Generate code for getting an attribute. - Formatter getAttr(StringRef attrName) const { + Formatter getAttr(StringRef attrName, bool named = false) const { + assert(attrMetadata.count(attrName) && "expected attribute metadata"); + return [this, attrName, named](raw_ostream &os) -> raw_ostream & { + const AttributeMetadata &attr = attrMetadata.find(attrName)->second; + return os << formatv(subrangeGetAttr, getAttrName(attrName), + attr.lowerBound, attr.upperBound, getAttrRange(), + named ? "Named" : ""); + }; + } + + // Generate code for getting the name of an attribute. + Formatter getAttrName(StringRef attrName) const { return [this, attrName](raw_ostream &os) -> raw_ostream & { - if (!emitForOp) - return os << formatv("odsAttrs.get(\"{0}\")", attrName); - return os << formatv(opGetAttr, op.getGetterName(attrName)); + if (emitForOp) + return os << op.getGetterName(attrName) << "AttrName()"; + return os << formatv("{0}::{1}AttrName(*odsOpName)", op.getCppClassName(), + op.getGetterName(attrName)); }; } + // Get the code snippet for getting the named attribute range. + StringRef getAttrRange() const { + return emitForOp ? "(*this)->getAttrs()" : "odsAttrs"; + } + // Get the prefix code for emitting an error. Formatter emitErrorPrefix() const { return [this](raw_ostream &os) -> raw_ostream & { @@ -254,14 +304,74 @@ // Return the ODS operation wrapper. const Operator &getOp() const { return op; } + // Get the attribute metadata sorted by name.
+ const llvm::MapVector &getAttrMetadata() const { + return attrMetadata; + } + private: + // Compute the attribute metadata. + void computeAttrMetadata(); + // The operation ODS wrapper. const Operator &op; // True if code is being generate for an op. False for an adaptor. const bool emitForOp; + + // The attribute metadata, mapped by name. + llvm::MapVector attrMetadata; + // The number of required attributes. + unsigned numRequired; }; + } // namespace +void OpOrAdaptorHelper::computeAttrMetadata() { + // Enumerate the attribute names of this op, ensuring the attribute names are + // unique in case implicit attributes are explicitly registered. + for (const NamedAttribute &namedAttr : op.getAttributes()) { + Attribute attr = namedAttr.attr; + bool isOptional = + attr.hasDefaultValue() || attr.isOptional() || attr.isDerivedAttr(); + attrMetadata.insert( + {namedAttr.name, AttributeMetadata{namedAttr.name, !isOptional, attr}}); + } + // Include key attributes from several traits as implicitly registered. + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + attrMetadata.insert( + {operandSegmentAttrName, + AttributeMetadata{operandSegmentAttrName, /*isRequired=*/true, + /*constraint=*/llvm::None}}); + } + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) { + attrMetadata.insert( + {resultSegmentAttrName, + AttributeMetadata{resultSegmentAttrName, /*isRequired=*/true, + /*constraint=*/llvm::None}}); + } + + // Store the metadata in sorted order. + SmallVector sortedAttrMetadata = + llvm::to_vector(llvm::make_second_range(attrMetadata.takeVector())); + llvm::sort(sortedAttrMetadata, + [](const AttributeMetadata &lhs, const AttributeMetadata &rhs) { + return lhs.attrName < rhs.attrName; + }); + + // Compute the subrange bounds for each attribute.
+ numRequired = 0; + for (AttributeMetadata &attr : sortedAttrMetadata) { + attr.lowerBound = numRequired; + numRequired += attr.isRequired; + } + for (AttributeMetadata &attr : sortedAttrMetadata) + attr.upperBound = numRequired - attr.lowerBound - attr.isRequired; + + // Store the results back into the map. + for (const AttributeMetadata &attr : sortedAttrMetadata) + attrMetadata.insert({attr.attrName, attr}); +} + //===----------------------------------------------------------------------===// // Op emitter //===----------------------------------------------------------------------===// @@ -448,6 +558,9 @@ // The emitter containing all of the locally emitted verification functions. const StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // Helper for emitting op code. + OpOrAdaptorHelper emitHelper; }; } // namespace @@ -476,20 +589,107 @@ } } +/// Generate verification on native traits requiring attributes. +static void genNativeTraitAttrVerifier(MethodBody &body, + const OpOrAdaptorHelper &emitHelper) { + // Check that the variadic segment sizes attribute, already found by the + // required-attribute scan, contains the expected number of elements. + // + // {0}: Attribute name. + // {1}: Expected number of elements. + // {2}: "operand" or "result". + // {3}: Emit error prefix. + const char *const checkAttrSizedValueSegmentsCode = R"( + { + auto sizeAttr = requiredAttr_{0}.cast<::mlir::DenseIntElementsAttr>(); + auto numElements = + sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); + if (numElements != {1}) + return {3}"'{0}' attribute for specifying {2} segments must have {1} " + "elements, but got ") << numElements; + } + )"; + + // Verify a few traits first so that we can use getODSOperands() and + // getODSResults() in the rest of the verifier.
+ auto &op = emitHelper.getOp(); + for (auto &trait : op.getTraits()) { + auto *t = dyn_cast(&trait); + if (!t) + continue; + std::string traitName = t->getFullyQualifiedTraitName(); + if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") { + body << formatv(checkAttrSizedValueSegmentsCode, operandSegmentAttrName, + op.getNumOperands(), "operand", + emitHelper.emitErrorPrefix()); + } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") { + body << formatv(checkAttrSizedValueSegmentsCode, resultSegmentAttrName, + op.getNumResults(), "result", + emitHelper.emitErrorPrefix()); + } + } +} + // Generate attribute verification. If an op instance is not available, then // attribute checks that require one will not be emitted. +// +// Attribute verification is performed as follows: +// +// 1. Verify that all required attributes are present in sorted order. This +// ensures that we can use subrange lookup even with potentially missing +// attributes. +// 2. Verify native trait attributes so that other attributes may call methods +// that depend on the validity of these attributes, e.g. segment size attributes +// and operand or result getters. +// 3. Verify the constraints on all present attributes. static void genAttributeVerifier( const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, MethodBody &body, const StaticVerifierFunctionEmitter &staticVerifierEmitter) { - // Check that a required attribute exists. - // - // {0}: Attribute variable name. - // {1}: Emit error prefix. - // {2}: Attribute name. - const char *const verifyRequiredAttr = R"( - if (!{0}) - return {1}"requires attribute '{2}'"); + // Check that all required attributes are present.
+ auto requiredAttributes = llvm::make_filter_range( + emitHelper.getAttrMetadata(), + [](const std::pair &metadata) { + return metadata.second.isRequired; + }); + + if (!requiredAttributes.empty()) { + body.indent(); + + // Since the attributes are sorted, we can check required attributes by + // traversing the array once. + body << formatv("auto namedAttrRange = {0};\n", emitHelper.getAttrRange()); + body << "auto namedAttrIt = namedAttrRange.begin();\n"; + + // Traverse the array until the required attribute is found. Return an error + // if the traversal reached the end. + // + // {0}: Code to get the name of the attribute. + // {1}: The emit error prefix. + // {2}: The name of the attribute. + const char *const findRequiredAttr = R"( +::mlir::Attribute requiredAttr_{2}; +while (true) { + if (namedAttrIt == namedAttrRange.end()) + return {1}"requires attribute '{2}'"); + if (namedAttrIt->getName() == {0}) {{ + requiredAttr_{2} = namedAttrIt->getValue(); + break; + } + ++namedAttrIt; +} )"; + for (const AttributeMetadata &requiredAttr : + llvm::make_second_range(requiredAttributes)) { + body << formatv(findRequiredAttr, + emitHelper.getAttrName(requiredAttr.attrName), + emitHelper.emitErrorPrefix(), requiredAttr.attrName); + } + + body.unindent(); + } + + genNativeTraitAttrVerifier(body, emitHelper); + // Verify the attribute if it is present. This assumes that default values // are valid. This code snippet pastes the condition inline. // @@ -521,7 +721,6 @@ if (attr.isDerivedAttr()) continue; - bool allowMissingAttr = attr.hasDefaultValue() || attr.isOptional(); auto attrPred = attr.getPredicate(); std::string condition = attrPred.isNull() ? "" : attrPred.getCondition(); // If the attribute's condition needs an op but none is available, then the @@ -535,16 +734,17 @@ - // If the attribute is not required and we cannot emit the condition, then
- // there is nothing to be done. - if (allowMissingAttr && !canEmitCondition) + // If the condition cannot be emitted, there is nothing to be done. Required + // attributes have already been checked for presence above. + if (!canEmitCondition) continue; - body << formatv(" {\n auto {0} = {1};", varName, - emitHelper.getAttr(attrName)); - - if (!allowMissingAttr) { - body << formatv(verifyRequiredAttr, varName, emitHelper.emitErrorPrefix(), + if (!attr.hasDefaultValue() && !attr.isOptional()) { + body << formatv(" {\n auto {0} = requiredAttr_{1};", varName, attrName); + } else { + body << formatv(" {\n auto {0} = {1};", varName, + emitHelper.getAttr(attrName)); } + if (canEmitCondition) { Optional constraintFn; if (emitHelper.isEmittingForOp() && @@ -573,7 +773,8 @@ : def(op.getDef()), op(op), opClass(op.getCppClassName(), op.getExtraClassDeclaration(), formatExtraDefinitions(op)), - staticVerifierEmitter(staticVerifierEmitter) { + staticVerifierEmitter(staticVerifierEmitter), + emitHelper(op, /*emitForOp=*/true) { verifyCtx.withOp("(*this->getOperation())"); verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()"); @@ -635,26 +836,12 @@ op.getOperationName() + " (from line " + Twine(line) + ")"); } + #define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) void OpEmitter::genAttrNameGetters() { - // A map of attribute names (including implicit attributes) registered to the - // current operation, to the relative order in which they were registered. - llvm::MapVector attributeNames; - - // Enumerate the attribute names of this op, assigning each a relative - // ordering. - auto addAttrName = [&](StringRef name) { - unsigned index = attributeNames.size(); - attributeNames.insert({name, index}); - }; - for (const NamedAttribute &namedAttr : op.getAttributes()) - addAttrName(namedAttr.name); - // Include key attributes from several traits as implicitly registered.
- if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) - addAttrName("operand_segment_sizes"); - if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) - addAttrName("result_segment_sizes"); + const llvm::MapVector &attributes = + emitHelper.getAttrMetadata(); // Emit the getAttributeNames method. { @@ -662,20 +849,18 @@ "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames"); ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); - if (attributeNames.empty()) { + if (attributes.empty()) { body << " return {};"; - } else { - body << " static ::llvm::StringRef attrNames[] = {"; - llvm::interleaveComma(llvm::make_first_range(attributeNames), body, - [&](StringRef attrName) { - body << "::llvm::StringRef(\"" << attrName - << "\")"; - }); - body << "};\n return ::llvm::makeArrayRef(attrNames);"; + // Nothing else to do if there are no registered attributes. Exit early. + return; } + body << " static ::llvm::StringRef attrNames[] = {"; + llvm::interleaveComma(llvm::make_first_range(attributes), body, + [&](StringRef attrName) { + body << "::llvm::StringRef(\"" << attrName << "\")"; + }); + body << "};\n return ::llvm::makeArrayRef(attrNames);"; } - if (attributeNames.empty()) - return; // Emit the getAttributeNameForIndex methods. { @@ -697,14 +882,14 @@ assert(index < {0} && "invalid attribute index"); return name.getRegisteredInfo()->getAttributeNames()[index]; )"; - method->body() << formatv(getAttrName, attributeNames.size()); + method->body() << formatv(getAttrName, attributes.size()); } // Generate the AttrName methods, that expose the attribute names to // users.
const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; - for (const std::pair &attrIt : attributeNames) { - for (StringRef name : op.getGetterNames(attrIt.first)) { + for (auto &attrIt : llvm::enumerate(llvm::make_first_range(attributes))) { + for (StringRef name : op.getGetterNames(attrIt.value())) { std::string methodName = (name + "AttrName").str(); // Generate the non-static variant. @@ -712,7 +897,7 @@ auto *method = opClass.addInlineMethod("::mlir::StringAttr", methodName); ERROR_IF_PRUNED(method, methodName, op); - method->body() << llvm::formatv(attrNameMethodBody, attrIt.second); + method->body() << llvm::formatv(attrNameMethodBody, attrIt.index()); } // Generate the static variant. @@ -722,7 +907,7 @@ MethodParameter("::mlir::OperationName", "name")); ERROR_IF_PRUNED(method, methodName, op); method->body() << llvm::formatv(attrNameMethodBody, - "name, " + Twine(attrIt.second)); + "name, " + Twine(attrIt.index())); } } } @@ -772,12 +957,13 @@ // Generate named accessor with Attribute return type. This is a wrapper class // that allows referring to the attributes via accessors instead of having to // use the string interface for better compile time verification. - auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { + auto emitAttrWithStorageType = [&](StringRef name, StringRef attrName, + Attribute attr) { auto *method = opClass.addMethod(attr.getStorageType(), name + "Attr"); if (!method) return; method->body() << formatv( - " return {0}.{1}<{2}>();", formatv(opGetAttr, name), + " return {0}.{1}<{2}>();", emitHelper.getAttr(attrName), attr.isOptional() || attr.hasDefaultValue() ?
"dyn_cast_or_null" : "cast", attr.getStorageType()); @@ -788,7 +974,7 @@ if (namedAttr.attr.isDerivedAttr()) { emitDerivedAttr(name, namedAttr.attr); } else { - emitAttrWithStorageType(name, namedAttr.attr); + emitAttrWithStorageType(name, namedAttr.name, namedAttr.attr); emitAttrGetterWithReturnType(fctx, opClass, op, name, namedAttr.attr); } } @@ -1041,8 +1227,8 @@ // array. std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string attr = op.getGetterName("operand_segment_sizes"); - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); } generateNamedOperandGetters( @@ -1073,17 +1259,17 @@ << " auto mutableRange = " "::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; - if (attrSizedOperands) - body << ", ::mlir::MutableOperandRange::OperandSegment(" << i - << "u, *getOperation()->getAttrDictionary().getNamed(" - << op.getGetterName("operand_segment_sizes") << "AttrName()))"; + if (attrSizedOperands) { + body << formatv( + ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})", i, + emitHelper.getAttr(operandSegmentAttrName, /*named=*/true)); + } body << ");\n"; // If this operand is a nested variadic, we split the range into a // MutableOperandRangeRange that provides a range over all of the // sub-ranges. if (operand.isVariadicOfVariadic()) { - // body << " return " "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" << op.getGetterName( @@ -1129,8 +1315,8 @@ // Build the initializer string for the result segment size attribute.
std::string attrSizeInitCode; if (attrSizedResults) { - std::string attr = op.getGetterName("result_segment_sizes"); - attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); + attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, + emitHelper.getAttr(resultSegmentAttrName)); } generateValueRangeStartAndEnd( @@ -1724,7 +1910,7 @@ const NamedAttribute &namedAttr = *arg.get(); const Attribute &attr = namedAttr.attr; - // inferred attributes don't need to be added to the param list. + // Inferred attributes don't need to be added to the param list. if (inferredAttributes.contains(namedAttr.name)) continue; @@ -1797,7 +1983,7 @@ // If the operation has the operand segment size attribute, add it here. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string sizes = op.getGetterName("operand_segment_sizes"); + std::string sizes = op.getGetterName(operandSegmentAttrName); body << " " << builderOpState << ".addAttribute(" << sizes << "AttrName(" << builderOpState << ".name), " << "odsBuilder.getI32VectorAttr({"; @@ -2160,64 +2346,13 @@ ERROR_IF_PRUNED(method, "print", op); } -/// Generate verification on native traits requiring attributes. -static void genNativeTraitAttrVerifier(MethodBody &body, - const OpOrAdaptorHelper &emitHelper) { - // Check that the variadic segment sizes attribute exists and contains the - // expected number of elements. - // - // {0}: Attribute name. - // {1}: Expected number of elements. - // {2}: "operand" or "result". - // {3}: Attribute getter call. - // {4}: Emit error prefix.
- const char *const checkAttrSizedValueSegmentsCode = R"( - { - auto sizeAttr = {3}.dyn_cast<::mlir::DenseIntElementsAttr>(); - if (!sizeAttr) - return {4}"missing segment sizes attribute '{0}'"); - auto numElements = - sizeAttr.getType().cast<::mlir::ShapedType>().getNumElements(); - if (numElements != {1}) - return {4}"'{0}' attribute for specifying {2} segments must have {1} " - "elements, but got ") << numElements; - } - )"; - - // Verify a few traits first so that we can use getODSOperands() and - // getODSResults() in the rest of the verifier. - auto &op = emitHelper.getOp(); - for (auto &trait : op.getTraits()) { - auto *t = dyn_cast(&trait); - if (!t) - continue; - std::string traitName = t->getFullyQualifiedTraitName(); - if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") { - StringRef attrName = "operand_segment_sizes"; - body << formatv(checkAttrSizedValueSegmentsCode, attrName, - op.getNumOperands(), "operand", - emitHelper.getAttr(attrName), - emitHelper.emitErrorPrefix()); - } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") { - StringRef attrName = "result_segment_sizes"; - body << formatv( - checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(), - "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix()); - } - } -} - void OpEmitter::genVerifier() { auto *implMethod = opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl"); ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op); auto &implBody = implMethod->body(); - OpOrAdaptorHelper emitHelper(op, /*isOp=*/true); - genNativeTraitAttrVerifier(implBody, emitHelper); - populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter); genOperandResultVerifier(implBody, op.getOperands(), "operand"); genOperandResultVerifier(implBody, op.getResults(), "result"); @@ -2570,34 +2705,49 @@ // getters identical to those defined in the Op.
class OpOperandAdaptorEmitter { public: - static void emitDecl(const Operator &op, - StaticVerifierFunctionEmitter &staticVerifierEmitter, - raw_ostream &os); - static void emitDef(const Operator &op, - StaticVerifierFunctionEmitter &staticVerifierEmitter, - raw_ostream &os); + static void + emitDecl(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); + static void + emitDef(const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, + raw_ostream &os); private: explicit OpOperandAdaptorEmitter( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter); + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter); // Add verification function. This generates a verify method for the adaptor // which verifies all the op-independent attribute constraints. void addVerification(); + // The operation for which to emit an adaptor. const Operator &op; + + // The generated adaptor class. Class adaptor; - StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // The emitter containing all of the locally emitted verification functions. + const StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // Helper for emitting adaptor code. 
+ OpOrAdaptorHelper emitHelper; }; } // namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter) + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter) : op(op), adaptor(op.getAdaptorName()), - staticVerifierEmitter(staticVerifierEmitter) { + staticVerifierEmitter(staticVerifierEmitter), + emitHelper(op, /*emitForOp=*/false) { adaptor.addField("::mlir::ValueRange", "odsOperands"); adaptor.addField("::mlir::DictionaryAttr", "odsAttrs"); adaptor.addField("::mlir::RegionRange", "odsRegions"); + adaptor.addField("::llvm::Optional<::mlir::OperationName>", "odsOpName"); + const auto *attrSizedOperands = op.getTrait("::m::OpTrait::AttrSizedOperandSegments"); { @@ -2611,14 +2761,21 @@ constructor->addMemberInitializer("odsOperands", "values"); constructor->addMemberInitializer("odsAttrs", "attrs"); constructor->addMemberInitializer("odsRegions", "regions"); + + MethodBody &body = constructor->body(); + body.indent() << "if (odsAttrs)\n"; + body.indent() << formatv( + "odsOpName.emplace(\"{0}\", odsAttrs.getContext());\n", + op.getOperationName()); } { - auto *constructor = adaptor.addConstructor( - MethodParameter(op.getCppClassName() + " &", "op")); + auto *constructor = + adaptor.addConstructor(MethodParameter(op.getCppClassName(), "op")); constructor->addMemberInitializer("odsOperands", "op->getOperands()"); constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); constructor->addMemberInitializer("odsRegions", "op->getRegions()"); + constructor->addMemberInitializer("odsOpName", "op->getName()"); } { @@ -2626,8 +2783,11 @@ ERROR_IF_PRUNED(m, "getOperands", op); m->body() << " return odsOperands;"; } - std::string attr = "operand_segment_sizes"; - std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, attr); + std::string sizeAttrInit; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + sizeAttrInit = 
formatv(adapterSegmentSizeAttrInitCode, + emitHelper.getAttr(operandSegmentAttrName)); + } generateNamedOperandGetters(op, adaptor, /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", @@ -2643,15 +2803,13 @@ Attribute attr) { auto *method = adaptor.addMethod(attr.getStorageType(), emitName + "Attr"); ERROR_IF_PRUNED(method, "Adaptor::" + emitName + "Attr", op); - auto &body = method->body(); - body << " assert(odsAttrs && \"no attributes when constructing adapter\");" - << "\n " << attr.getStorageType() << " attr = " - << "odsAttrs.get(\"" << name << "\")."; - if (attr.hasDefaultValue() || attr.isOptional()) - body << "dyn_cast_or_null<"; - else - body << "cast<"; - body << attr.getStorageType() << ">();\n"; + auto &body = method->body().indent(); + body << "assert(odsAttrs && \"no attributes when constructing adapter\");\n" + << formatv("auto attr = {0}.{1}<{2}>();\n", emitHelper.getAttr(name), + attr.hasDefaultValue() || attr.isOptional() + ? "dyn_cast_or_null" + : "cast", + attr.getStorageType()); if (attr.hasDefaultValue()) { // Use the default value if attribute is not set. 
@@ -2717,24 +2875,23 @@ ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); - OpOrAdaptorHelper emitHelper(op, /*isOp=*/false); - genNativeTraitAttrVerifier(body, emitHelper); FmtContext verifyCtx; populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter); body << " return ::mlir::success();"; } void OpOperandAdaptorEmitter::emitDecl( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDeclTo(os); } void OpOperandAdaptorEmitter::emitDef( - const Operator &op, StaticVerifierFunctionEmitter &staticVerifierEmitter, + const Operator &op, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, raw_ostream &os) { OpOperandAdaptorEmitter(op, staticVerifierEmitter).adaptor.writeDefTo(os); }