diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2836,20 +2836,24 @@ // 'Parameters' should be subclasses of this or simple strings (which is a // shorthand for AttrOrTypeParameter<"C++Type">). -class AttrOrTypeParameter<string type, string desc> { +class AttrOrTypeParameter<string type, string desc, string accessorType = ""> { // Custom memory allocation code for storage constructor. code allocator = ?; // Custom comparator used to compare two instances for equality. code comparator = ?; // The C++ type of this parameter. string cppType = type; + // The C++ type of the accessor for this parameter. + string cppAccessorType = !if(!empty(accessorType), type, accessorType); // One-line human-readable description of the argument. string summary = desc; // The format string for the asm syntax (documentation only). string syntax = ?; } -class AttrParameter<string type, string desc> : AttrOrTypeParameter<type, desc>; -class TypeParameter<string type, string desc> : AttrOrTypeParameter<type, desc>; +class AttrParameter<string type, string desc, string accessorType = ""> + : AttrOrTypeParameter<type, desc, accessorType>; +class TypeParameter<string type, string desc, string accessorType = ""> + : AttrOrTypeParameter<type, desc, accessorType>; // For StringRefs, which require allocation. class StringRefParameter<string desc = ""> : diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -196,6 +196,9 @@ // Get the C++ type of this parameter. StringRef getCppType() const; + // Get the C++ accessor type of this parameter. + StringRef getCppAccessorType() const; + // Get a description of this parameter for documentation purposes.
Optional<StringRef> getSummary() const; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -210,6 +210,15 @@ "which inherit from AttrOrTypeParameter\n"); } +StringRef AttrOrTypeParameter::getCppAccessorType() const { + if (auto *param = dyn_cast<llvm::DefInit>(def->getArg(index))) { + if (Optional<StringRef> type = + param->getDef()->getValueAsOptionalString("cppAccessorType")) + return *type; + } + return getCppType(); +} + Optional<StringRef> AttrOrTypeParameter::getSummary() const { auto *parameterType = def->getArg(index); if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) { diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -135,3 +135,14 @@ // DEF-LABEL: struct AttrWithTypeBuilderAttrStorage // DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr) // DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr) + +def F_ParamWithAccessorTypeAttr : TestAttr<"ParamWithAccessorType"> { + let parameters = (ins AttrParameter<"std::string", "", "StringRef">:$param); +} + +// DECL-LABEL: class ParamWithAccessorTypeAttr +// DECL: StringRef getParam() +// DEF: ParamWithAccessorTypeAttrStorage +// DEF-NEXT: ParamWithAccessorTypeAttrStorage (std::string param) +// DEF: StringRef ParamWithAccessorTypeAttr::getParam() + diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -413,7 +413,8 @@ for (AttrOrTypeParameter &parameter : parameters) { SmallString<16> name = parameter.getName(); name[0] = llvm::toUpper(name[0]); - os << formatv(" {0} get{1}() const;\n", parameter.getCppType(), name); + os << formatv(" {0} get{1}() const;\n", parameter.getCppAccessorType(), + name); } } @@ -859,7 +860,7 @@ SmallString<16> name = param.getName();
name[0] = llvm::toUpper(name[0]); os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n", - param.getCppType(), name, paramStorageName, + param.getCppAccessorType(), name, paramStorageName, def.getCppClassName()); } }