diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -77,6 +77,12 @@ // DEF: void AOp::cAttrAttr(some-attr-kind attr) { // DEF-NEXT: (*this)->setAttr("cAttr", attr); +// Test remove methods +// --- + +// DEF: Attribute AOp::removeCAttrAttr() { +// DEF-NEXT: return (*this)->removeAttr("cAttr"); + // Test build methods // --- @@ -265,6 +271,9 @@ // DEF: bool UnitAttrOp::attr() { // DEF: return {{.*}} != nullptr +// DEF: Attribute UnitAttrOp::removeAttrAttr() { +// DEF-NEXT: (*this)->removeAttr("attr"); + // DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr) diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -81,6 +81,7 @@ // CHECK: uint32_t attr1(); // CHECK: ::mlir::FloatAttr attr2Attr() // CHECK: ::llvm::Optional< ::llvm::APFloat > attr2(); +// CHECK: Attribute removeAttr2Attr(); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value val); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, int integer = 0); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, ::mlir::IntegerAttr attr1, /*optional*/::mlir::FloatAttr attr2, unsigned someRegionsCount) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -328,6 +328,9 @@ // Generates setter for the attributes. void genAttrSetters(); + // Generates removers for optional attributes. + void genOptionalAttrRemovers(); + // Generates getters for named operands. 
void genNamedOperandGetters(); @@ -600,6 +603,7 @@ genNamedSuccessorGetters(); genAttrGetters(); genAttrSetters(); + genOptionalAttrRemovers(); genBuilder(); genParser(); genPrinter(); @@ -777,6 +781,30 @@ } } +void OpEmitter::genOptionalAttrRemovers() { + // Generate methods for removing optional attributes, instead of having to + // use the string interface. Enables better compile time verification. + auto emitRemoveAttr = [&](StringRef name) { + auto upperInitial = name.take_front().upper(); + auto suffix = name.drop_front(); + // Fully qualify the return type: the generated method may be compiled in a + // namespace where an unqualified `Attribute` does not resolve. + auto *method = opClass.addMethodAndPrune( + "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); + if (!method) + return; + auto &body = method->body(); + body << " return (*this)->removeAttr(\"" << name << "\");"; + }; + + for (const auto &namedAttr : op.getAttributes()) { + const auto &name = namedAttr.name; + const auto &attr = namedAttr.attr; + if (attr.isOptional()) + emitRemoveAttr(name); + } +} + // Generates the code to compute the start and end index of an operand or result // range. template