diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -201,6 +201,17 @@
 // DEF: odsState.addAttribute("str_attr", (*odsBuilder).getStringAttr(str_attr));
 // DEF: odsState.addAttribute("dv_str_attr", (*odsBuilder).getStringAttr(dv_str_attr));
 
+// Test derived type attr.
+// ---
+def DerivedTypeAttrOp : NS_Op<"derived_type_attr_op", []> {
+  let results = (outs AnyTensor:$output);
+  DerivedTypeAttr element_dtype = DerivedTypeAttr<"return output().getType();">;
+}
+
+// DECL: static bool isDerivedAttribute
+// DEF: bool DerivedTypeAttrOp::isDerivedAttribute(StringRef name) {
+// DEF: return llvm::is_contained(llvm::makeArrayRef({"element_dtype"}), name);
+// DEF: }
 // Test that only default valued attributes at the end of the arguments
 // list get default values in the builder signature
 
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -391,6 +391,31 @@
       emitAttrWithReturnType(name, attr);
     }
   }
+
+  // Generate a helper method to query whether a named attribute is a derived
+  // attribute. This enables, for example, avoiding adding an attribute that
+  // overlaps with a derived attribute. The method only inspects the op
+  // definition, so it is static and can be queried without an op instance
+  // (e.g. from within a builder).
+  auto &method = opClass.newMethod("bool", "isDerivedAttribute",
+                                   "StringRef name", OpMethod::MP_Static);
+  auto &body = method.body();
+  auto derivedAttr = make_filter_range(op.getAttributes(),
+                                       [](const NamedAttribute &namedAttr) {
+                                         return namedAttr.attr.isDerivedAttr();
+                                       });
+  if (derivedAttr.empty()) {
+    body << " return false;";
+  } else {
+    // Emits: return llvm::is_contained(llvm::makeArrayRef({"a", ...}), name);
+    // is_contained(Range, Element) needs the queried name as its 2nd argument.
+    body << " return llvm::is_contained(llvm::makeArrayRef({";
+    mlir::interleaveComma(derivedAttr, body,
+                          [&](const NamedAttribute &namedAttr) {
+                            body << "\"" << namedAttr.name << "\"";
+                          });
+    body << "}), name);";
+  }
 }
 
 void OpEmitter::genAttrSetters() {