diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -77,6 +77,7 @@ AffineMapAccessInterface::Trait> { public: using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, @@ -268,6 +269,7 @@ AffineMapAccessInterface::Trait> { public: using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements); diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -92,6 +92,7 @@ : public Op { public: using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, @@ -215,6 +216,7 @@ : public Op { public: using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, ValueRange tagIndices, Value numElements); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -172,7 +172,7 @@ T::getParseAssemblyFn(), T::getPrintAssemblyFn(), T::getVerifyInvariantsFn(), T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(), - T::getHasTraitFn()); + T::getHasTraitFn(), T::getAttributeNames()); } /// Register a new operation in a Dialect object. 
@@ -183,7 +183,11 @@ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait); + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, + ArrayRef attrNames); + + /// Return the list of cached attribute names registered to this operation. + ArrayRef getAttributeNames() const { return attributeNames; } private: AbstractOperation(StringRef name, Dialect &dialect, TypeID typeID, @@ -192,7 +196,8 @@ VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait); + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, + ArrayRef attrNames); /// A map of interfaces that were registered to this operation. detail::InterfaceMap interfaceMap; @@ -204,6 +209,11 @@ ParseAssemblyFn parseAssemblyFn; PrintAssemblyFn printAssemblyFn; VerifyInvariantsFn verifyInvariantsFn; + + /// A list of attribute names registered to this operation in identifier form. + /// This allows for operation classes to use identifiers for attribute + /// lookup/creation/etc., as opposed to strings. 
+ ArrayRef attributeNames; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -702,17 +702,30 @@ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) { - AbstractOperation opInfo(name, dialect, typeID, std::move(parseAssembly), - std::move(printAssembly), - std::move(verifyInvariants), std::move(foldHook), - std::move(getCanonicalizationPatterns), - std::move(interfaceMap), std::move(hasTrait)); - - auto &impl = dialect.getContext()->getImpl(); + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, + ArrayRef attrNames) { + MLIRContext *ctx = dialect.getContext(); + auto &impl = ctx->getImpl(); assert(impl.multiThreadedExecutionContext == 0 && "Registering a new operation kind while in a multi-threaded execution " "context"); + + // Register the attribute names of this operation. + MutableArrayRef identifierAttrNames; + if (!attrNames.empty()) { + identifierAttrNames = MutableArrayRef( + impl.identifierAllocator.Allocate(attrNames.size()), + attrNames.size()); + for (unsigned i : llvm::seq(0, attrNames.size())) + identifierAttrNames[i] = Identifier::get(attrNames[i], ctx); + } + + // Register the information for this operation.
+ AbstractOperation opInfo( + name, dialect, typeID, std::move(parseAssembly), std::move(printAssembly), + std::move(verifyInvariants), std::move(foldHook), + std::move(getCanonicalizationPatterns), std::move(interfaceMap), + std::move(hasTrait), identifierAttrNames); if (!impl.registeredOperations.insert({name, std::move(opInfo)}).second) { llvm::errs() << "error: operation named '" << name << "' is already registered.\n"; @@ -725,7 +738,8 @@ ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly, VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook, GetCanonicalizationPatternsFn &&getCanonicalizationPatterns, - detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait) + detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait, + ArrayRef attrNames) : name(Identifier::get(name, dialect.getContext())), dialect(dialect), typeID(typeID), interfaceMap(std::move(interfaceMap)), foldHookFn(std::move(foldHook)), @@ -733,7 +747,8 @@ hasTraitFn(std::move(hasTrait)), parseAssemblyFn(std::move(parseAssembly)), printAssemblyFn(std::move(printAssembly)), - verifyInvariantsFn(std::move(verifyInvariants)) {} + verifyInvariantsFn(std::move(verifyInvariants)), + attributeNames(attrNames) {} //===----------------------------------------------------------------------===// // AbstractType diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -201,15 +201,15 @@ os.indent(2); if (isStatic()) os << "static "; - if (properties & MP_Constexpr) + if ((properties & MP_Constexpr) == MP_Constexpr) os << "constexpr "; methodSignature.writeDeclTo(os); - if (!isInline()) + if (!isInline()) { os << ";"; - else { + } else { os << " {\n"; - methodBody.writeTo(os); - os << "}"; + methodBody.writeTo(os.indent(2)); + os.indent(2) << "}"; } } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ 
b/mlir/test/mlir-tblgen/op-attribute.td @@ -28,6 +28,31 @@ ); } +// DECL-LABEL: AOp declarations + +// Test attribute name methods +// --- + +// DECL: static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() +// DECL-NEXT: static ::llvm::StringRef attrNames[] = +// DECL-SAME: {::llvm::StringRef("aAttr"), ::llvm::StringRef("bAttr"), ::llvm::StringRef("cAttr")}; +// DECL-NEXT: return attrNames; + +// DECL: ::mlir::Identifier aAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(0); +// DECL: ::mlir::Identifier aAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 0); + +// DECL: ::mlir::Identifier bAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(1); +// DECL: ::mlir::Identifier bAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 1); + +// DECL: ::mlir::Identifier cAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(2); +// DECL: ::mlir::Identifier cAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 2); + // DEF-LABEL: AOp definitions // Test verify method @@ -48,13 +73,13 @@ // --- // DEF: some-attr-kind AOp::aAttrAttr() -// DEF-NEXT: (*this)->getAttr("aAttr").template cast() +// DEF-NEXT: (*this)->getAttr(aAttrAttrName()).template cast() // DEF: some-return-type AOp::aAttr() { // DEF-NEXT: auto attr = aAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::bAttrAttr() -// DEF-NEXT: return (*this)->getAttr("bAttr").template dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr(bAttrAttrName()).template dyn_cast_or_null() // DEF: some-return-type AOp::bAttr() { // DEF-NEXT: auto attr = bAttrAttr(); // DEF-NEXT: if (!attr) @@ -62,7 +87,7 @@ // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::cAttrAttr() -// DEF-NEXT: return (*this)->getAttr("cAttr").template dyn_cast_or_null() +// DEF-NEXT: return 
(*this)->getAttr(cAttrAttrName()).template dyn_cast_or_null() // DEF: ::llvm::Optional AOp::cAttr() { // DEF-NEXT: auto attr = cAttrAttr() // DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); @@ -71,30 +96,30 @@ // --- // DEF: void AOp::aAttrAttr(some-attr-kind attr) { -// DEF-NEXT: (*this)->setAttr("aAttr", attr); +// DEF-NEXT: (*this)->setAttr(aAttrAttrName(), attr); // DEF: void AOp::bAttrAttr(some-attr-kind attr) { -// DEF-NEXT: (*this)->setAttr("bAttr", attr); +// DEF-NEXT: (*this)->setAttr(bAttrAttrName(), attr); // DEF: void AOp::cAttrAttr(some-attr-kind attr) { -// DEF-NEXT: (*this)->setAttr("cAttr", attr); +// DEF-NEXT: (*this)->setAttr(cAttrAttrName(), attr); // Test remove methods // --- // DEF: ::mlir::Attribute AOp::removeCAttrAttr() { -// DEF-NEXT: return (*this)->removeAttr("cAttr"); +// DEF-NEXT: return (*this)->removeAttr(cAttrAttrName()); // Test build methods // --- // DEF: void AOp::build( -// DEF: odsState.addAttribute("aAttr", aAttr); -// DEF: odsState.addAttribute("bAttr", bAttr); +// DEF: odsState.addAttribute(aAttrAttrName(odsState.name), aAttr); +// DEF: odsState.addAttribute(bAttrAttrName(odsState.name), bAttr); // DEF: if (cAttr) { -// DEF-NEXT: odsState.addAttribute("cAttr", cAttr); +// DEF-NEXT: odsState.addAttribute(cAttrAttrName(odsState.name), cAttr); // DEF: void AOp::build( // DEF: some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr -// DEF: odsState.addAttribute("aAttr", some-const-builder-call(odsBuilder, aAttr)); +// DEF: odsState.addAttribute(aAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr)); // DEF: void AOp::build( // DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes @@ -205,8 +230,8 @@ // DECL: static void build({{.*}}, uint32_t i32_attr, ::llvm::APFloat f64_attr, ::llvm::StringRef str_attr, bool bool_attr, ::SomeI32Enum enum_attr, uint32_t dv_i32_attr, ::llvm::APFloat dv_f64_attr, ::llvm::StringRef dv_str_attr = "abc", 
bool dv_bool_attr = true, ::SomeI32Enum dv_enum_attr = ::SomeI32Enum::case5) // DEF-LABEL: DOp definitions -// DEF: odsState.addAttribute("str_attr", odsBuilder.getStringAttr(str_attr)); -// DEF: odsState.addAttribute("dv_str_attr", odsBuilder.getStringAttr(dv_str_attr)); +// DEF: odsState.addAttribute(str_attrAttrName(odsState.name), odsBuilder.getStringAttr(str_attr)); +// DEF: odsState.addAttribute(dv_str_attrAttrName(odsState.name), odsBuilder.getStringAttr(dv_str_attr)); // Test derived type attr. // --- @@ -272,7 +297,7 @@ // DEF: return {{.*}} != nullptr // DEF: ::mlir::Attribute UnitAttrOp::removeAttrAttr() { -// DEF-NEXT: (*this)->removeAttr("attr"); +// DEF-NEXT: (*this)->removeAttr(attrAttrName()); // DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -21,6 +21,7 @@ #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/SideEffects.h" #include "mlir/TableGen/Trait.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Path.h" @@ -79,7 +80,7 @@ auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); )"; const char *opSegmentSizeAttrInitCode = R"( - auto sizeAttr = (*this)->getAttr("{0}").cast<::mlir::DenseIntElementsAttr>(); + auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>(); )"; const char *attrSizedSegmentValueRangeCalcCode = R"( auto sizeAttrValues = sizeAttr.getValues(); @@ -306,6 +307,13 @@ void emitDecl(raw_ostream &os); void emitDef(raw_ostream &os); + // Generate methods for accessing the attribute names of this operation. + void genAttrNameGetters(); + + // Return the index of the given attribute name. 
This is a relative ordering + // for this name, used in attribute getters. + unsigned getAttrNameIndex(StringRef attrName) const; + // Generates the OpAsmOpInterface for this operation if possible. void genOpAsmInterface(); @@ -460,6 +468,10 @@ // The emitter containing all of the locally emitted verification functions. const StaticVerifierFunctionEmitter &staticVerifierEmitter; + + // A map of attribute names (including implicit attributes) registered to the + // current operation, to the relative order in which they were registered. + llvm::MapVector attributeNames; }; } // end anonymous namespace @@ -584,6 +596,7 @@ // Generate C++ code for various op methods. The order here determines the // methods in the generated file. + genAttrNameGetters(); genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); @@ -622,18 +635,103 @@ void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } +void OpEmitter::genAttrNameGetters() { + // Enumerate the attribute names of this op, assigning each a relative + // ordering. + auto addAttrName = [&](StringRef name) { + unsigned index = attributeNames.size(); + attributeNames.insert({name, index}); + }; + for (const NamedAttribute &namedAttr : op.getAttributes()) + addAttrName(namedAttr.name); + // Include key attributes from several traits as implicitly registered. + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) + addAttrName("operand_segment_sizes"); + if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) + addAttrName("result_segment_sizes"); + + // Emit the getAttributeNames method. 
+ { + auto *method = opClass.addMethodAndPrune( + "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames", + OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline)); + auto &body = method->body(); + if (attributeNames.empty()) { + body << " return {};"; + } else { + body << " static ::llvm::StringRef attrNames[] = {"; + llvm::interleaveComma(llvm::make_first_range(attributeNames), body, + [&](StringRef attrName) { + body << "::llvm::StringRef(\"" << attrName + << "\")"; + }); + body << "};\n return attrNames;"; + } + } + + // Emit the getAttributeNameForIndex methods. + { + auto *method = opClass.addMethodAndPrune( + "::mlir::Identifier", "getAttributeNameForIndex", + OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline), + "unsigned", "index"); + method->body() + << " return getAttributeNameForIndex((*this)->getName(), index);"; + } + { + auto *method = opClass.addMethodAndPrune( + "::mlir::Identifier", "getAttributeNameForIndex", + OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline | + OpMethod::MP_Static), + "::mlir::OperationName name, unsigned index"); + method->body() << "assert(index < " << attributeNames.size() + << " && \"invalid attribute index\");\n" + " return name.getAbstractOperation()" + "->getAttributeNames()[index];"; + } + + // Generate the AttrName methods that expose the attribute names to + // users. + const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; + for (const std::pair &attrIt : attributeNames) { + std::string methodName = (attrIt.first + "AttrName").str(); + + // Generate the non-static variant. + { + auto *method = + opClass.addMethodAndPrune("::mlir::Identifier", methodName, + OpMethod::Property(OpMethod::MP_Inline)); + method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str(); + } + + // Generate the static variant.
+ { + auto *method = opClass.addMethodAndPrune( + "::mlir::Identifier", methodName, + OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static), + "::mlir::OperationName", "name"); + method->body() << llvm::formatv(attrNameMethodBody, + "name, " + Twine(attrIt.second)) + .str(); + } + } +} + +unsigned OpEmitter::getAttrNameIndex(StringRef attrName) const { + auto it = attributeNames.find(attrName); + assert(it != attributeNames.end() && "expected attribute name to have been " + "registered in genAttrNameGetters"); + return it->second; +} + void OpEmitter::genAttrGetters() { FmtContext fctx; fctx.withBuilder("::mlir::Builder((*this)->getContext())"); - Dialect opDialect = op.getDialect(); // Emit the derived attribute body. auto emitDerivedAttr = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); - if (!method) - return; - auto &body = method->body(); - body << " " << attr.getDerivedCodeBody() << "\n"; + if (auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name)) + method->body() << " " << attr.getDerivedCodeBody() << "\n"; }; // Emit with return type specified. 
@@ -666,7 +764,7 @@ if (!method) return; auto &body = method->body(); - body << " return (*this)->getAttr(\"" << name << "\").template "; + body << " return (*this)->getAttr(" << name << "AttrName()).template "; if (attr.isOptional() || attr.hasDefaultValue()) body << "dyn_cast_or_null<"; else @@ -674,14 +772,12 @@ body << attr.getStorageType() << ">();"; }; - for (auto &namedAttr : op.getAttributes()) { - const auto &name = namedAttr.name; - const auto &attr = namedAttr.attr; - if (attr.isDerivedAttr()) { - emitDerivedAttr(name, attr); + for (const NamedAttribute &namedAttr : op.getAttributes()) { + if (namedAttr.attr.isDerivedAttr()) { + emitDerivedAttr(namedAttr.name, namedAttr.attr); } else { - emitAttrWithStorageType(name, attr); - emitAttrWithReturnType(name, attr); + emitAttrWithStorageType(namedAttr.name, namedAttr.attr); + emitAttrWithReturnType(namedAttr.name, namedAttr.attr); } } @@ -738,8 +834,7 @@ derivedAttrs, body, [&](const NamedAttribute &namedAttr) { auto tmpl = namedAttr.attr.getConvertFromStorageCall(); - body << " {::mlir::Identifier::get(\"" << namedAttr.name - << "\", ctx),\n" + body << " {" << namedAttr.name << "AttrName(),\n" << tgfmt(tmpl, &fctx.withSelf(namedAttr.name + "()") .withBuilder("odsBuilder") .addSubst("_ctx", "ctx")) @@ -758,18 +853,13 @@ auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), attr.getStorageType(), "attr"); - if (!method) - return; - auto &body = method->body(); - body << " (*this)->setAttr(\"" << name << "\", attr);"; + if (method) + method->body() << " (*this)->setAttr(" << name << "AttrName(), attr);"; }; - for (auto &namedAttr : op.getAttributes()) { - const auto &name = namedAttr.name; - const auto &attr = namedAttr.attr; - if (!attr.isDerivedAttr()) - emitAttrWithStorageType(name, attr); - } + for (const NamedAttribute &namedAttr : op.getAttributes()) + if (!namedAttr.attr.isDerivedAttr()) + 
emitAttrWithStorageType(namedAttr.name, namedAttr.attr); } void OpEmitter::genOptionalAttrRemovers() { @@ -782,16 +872,12 @@ "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); if (!method) return; - auto &body = method->body(); - body << " return (*this)->removeAttr(\"" << name << "\");"; + method->body() << " return (*this)->removeAttr(" << name << "AttrName());"; }; - for (const auto &namedAttr : op.getAttributes()) { - const auto &name = namedAttr.name; - const auto &attr = namedAttr.attr; - if (attr.isOptional()) - emitRemoveAttr(name); - } + for (const NamedAttribute &namedAttr : op.getAttributes()) + if (namedAttr.attr.isOptional()) + emitRemoveAttr(namedAttr.name); } // Generates the code to compute the start and end index of an operand or result @@ -903,10 +989,18 @@ } void OpEmitter::genNamedOperandGetters() { + // Build the code snippet used for initializing the operand_segment_sizes + // array. + std::string attrSizeInitCode; + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + attrSizeInitCode = + formatv(opSegmentSizeAttrInitCode, "operand_segment_sizesAttrName()") + .str(); + } + generateNamedOperandGetters( op, opClass, - /*sizeAttrInit=*/ - formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(), + /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", /*rangeSizeCall=*/"getOperation()->getNumOperands()", @@ -929,7 +1023,7 @@ if (attrSizedOperands) body << ", ::mlir::MutableOperandRange::OperandSegment(" << i << "u, *getOperation()->getAttrDictionary().getNamed(" - "\"operand_segment_sizes\"))"; + "operand_segment_sizesAttrName()))"; body << ");\n"; } } @@ -963,11 +1057,18 @@ "'SameVariadicResultSize' traits"); } + // Build the initializer string for the result segment size attribute. 
+ std::string attrSizeInitCode; + if (attrSizedResults) { + attrSizeInitCode = + formatv(opSegmentSizeAttrInitCode, "result_segment_sizesAttrName()") + .str(); + } + generateValueRangeStartAndEnd( opClass, "getODSResultIndexAndLength", numVariadicResults, numNormalResults, "getOperation()->getNumResults()", attrSizedResults, - formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(), - op.getResults()); + attrSizeInitCode, op.getResults()); auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range", "getODSResults", "unsigned", "index"); @@ -1297,8 +1398,11 @@ std::string resultType; const auto &namedAttr = op.getAttribute(0); - body << " for (auto attr : attributes) {\n"; - body << " if (attr.first != \"" << namedAttr.name << "\") continue;\n"; + body << " auto attrName = " << namedAttr.name << "AttrName(" + << builderOpState + << ".name);\n" + " for (auto attr : attributes) {\n" + " if (attr.first != attrName) continue;\n"; if (namedAttr.attr.isTypeAttr()) { resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()"; } else { @@ -1595,8 +1699,9 @@ // If the operation has the operand segment size attribute, add it here. if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { body << " " << builderOpState - << ".addAttribute(\"operand_segment_sizes\", " - "odsBuilder.getI32VectorAttr({"; + << ".addAttribute(operand_segment_sizesAttrName(" << builderOpState + << ".name), " + << "odsBuilder.getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { if (op.getOperand(i).isOptional()) body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; @@ -1613,9 +1718,9 @@ auto &attr = namedAttr.attr; if (!attr.isDerivedAttr()) { bool emitNotNullCheck = attr.isOptional(); - if (emitNotNullCheck) { + if (emitNotNullCheck) body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; - } + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { // If this is a raw value, then we need to wrap it in an Attribute // instance. 
@@ -1634,15 +1739,14 @@ std::string value = std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - body << formatv(" {0}.addAttribute(\"{1}\", {2});\n", builderOpState, - namedAttr.name, value); + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", + builderOpState, namedAttr.name, value); } else { - body << formatv(" {0}.addAttribute(\"{1}\", {1});\n", builderOpState, - namedAttr.name); + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", + builderOpState, namedAttr.name); } - if (emitNotNullCheck) { + if (emitNotNullCheck) body << " }\n"; - } } } diff --git a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp --- a/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp +++ b/mlir/unittests/Interfaces/DataLayoutInterfacesTest.cpp @@ -114,6 +114,7 @@ /// types itself. struct OpWithLayout : public Op { using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static StringRef getOperationName() { return "dltest.op_with_layout"; } @@ -156,6 +157,7 @@ struct OpWith7BitByte : public Op { using Op::Op; + static ArrayRef getAttributeNames() { return {}; } static StringRef getOperationName() { return "dltest.op_with_7bit_byte"; }