diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -21,6 +21,7 @@ #include "mlir/IR/Value.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/StringSwitch.h" #include "llvm/Support/TrailingObjects.h" #include diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -201,6 +201,19 @@ // DEF: odsState.addAttribute("str_attr", (*odsBuilder).getStringAttr(str_attr)); // DEF: odsState.addAttribute("dv_str_attr", (*odsBuilder).getStringAttr(dv_str_attr)); +// Test derived type attr. +// --- +def DerivedTypeAttrOp : NS_Op<"derived_type_attr_op", []> { + let results = (outs AnyTensor:$output); + DerivedTypeAttr element_dtype = DerivedTypeAttr<"return output().getType();">; +} + +// DECL: bool isDerivedAttribute +// DEF: bool DerivedTypeAttrOp::isDerivedAttribute(StringRef name) { +// DEF: return llvm::StringSwitch<bool>(name) +// DEF: .Case("element_dtype", true) +// DEF: .Default(false); +// DEF: } // Test that only default valued attributes at the end of the arguments // list get default values in the builder signature diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -391,6 +391,18 @@ emitAttrWithReturnType(name, attr); } } + + // Generate helper method to query whether a named attribute is a derived + // attribute (true only for this op's derived attributes). This enables, for + // example, avoiding adding an attribute that overlaps with a derived one.
+ auto &method = + opClass.newMethod("bool", "isDerivedAttribute", "StringRef name"); + auto &body = method.body(); + body << " return llvm::StringSwitch<bool>(name)"; + for (const auto &namedAttr : op.getAttributes()) + if (namedAttr.attr.isDerivedAttr()) + body << "\n .Case(\"" << namedAttr.name << "\", true)"; + body << "\n .Default(false);"; } void OpEmitter::genAttrSetters() {