diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td @@ -39,6 +39,9 @@ let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; + + // TODO: Example for testing/example, submit as follow on change instead. + let emitAccessorPrefix = kEmitAccessorPrefix_Both; } def Shape_ShapeType : DialectType dependentDialects; diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp --- a/mlir/lib/TableGen/Dialect.cpp +++ b/mlir/lib/TableGen/Dialect.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Dialect.h" +#include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" using namespace mlir; @@ -89,6 +90,13 @@ return def->getValueAsBit("hasOperationInterfaceFallback"); } +Dialect::EmitPrefix Dialect::getEmitAccessorPrefix() const { + int prefix = def->getValueAsInt("emitAccessorPrefix"); + if (prefix < 0 || prefix > static_cast(EmitPrefix::Both)) + PrintFatalError(def->getLoc(), "Invalid accessor prefix value"); + return static_cast(prefix); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -130,6 +130,118 @@ // DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes // DEF: odsState.addAttributes(attributes); +// Test the above but with prefix. + +def Test2_Dialect : Dialect { + let name = "test2"; + let cppNamespace = "foobar2"; + let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; +} +def AgetOp : Op { + let arguments = (ins + SomeAttr:$aAttr, + DefaultValuedAttr:$bAttr, + OptionalAttr:$cAttr + ); +} + +// DECL-LABEL: AgetOp declarations + +// Test attribute name methods +// --- + +// DECL: static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() +// DECL-NEXT: static ::llvm::StringRef attrNames[] = +// DECL-SAME: {::llvm::StringRef("aAttr"), ::llvm::StringRef("bAttr"), ::llvm::StringRef("cAttr")}; +// DECL-NEXT: return ::llvm::makeArrayRef(attrNames); + +// DECL: ::mlir::Identifier getAAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(0); +// DECL: ::mlir::Identifier getAAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 0); + +// DECL: ::mlir::Identifier getBAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(1); +// DECL: ::mlir::Identifier getBAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 1); + +// DECL: ::mlir::Identifier getCAttrAttrName() +// DECL-NEXT: return getAttributeNameForIndex(2); +// DECL: ::mlir::Identifier getCAttrAttrName(::mlir::OperationName name) +// DECL-NEXT: return getAttributeNameForIndex(name, 2); + +// DEF-LABEL: AgetOp definitions + +// Test verify method +// --- + +// DEF: ::mlir::LogicalResult AgetOpAdaptor::verify +// DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); +// DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test2.a_get_op' op ""requires attribute 'aAttr'"); +// DEF: if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); +// DEF: auto tblgen_bAttr = odsAttrs.get("bAttr"); +// DEF-NEXT: if (tblgen_bAttr) { +// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'bAttr' failed to satisfy constraint: some attribute kind"); +// DEF: auto tblgen_cAttr = odsAttrs.get("cAttr"); +// DEF-NEXT: if (tblgen_cAttr) { +// DEF-NEXT: if (!((some-condition))) return emitError(loc, "'test2.a_get_op' op ""attribute 'cAttr' failed to satisfy constraint: some attribute kind"); + +// Test getter methods +// --- + +// DEF: some-attr-kind AgetOp::getAAttrAttr() +// DEF-NEXT: (*this)->getAttr(getAAttrAttrName()).template cast() +// DEF: some-return-type AgetOp::getAAttr() { +// DEF-NEXT: auto attr = getAAttrAttr() +// DEF-NEXT: return attr.some-convert-from-storage(); + +// DEF: some-attr-kind AgetOp::getBAttrAttr() +// DEF-NEXT: return (*this)->getAttr(getBAttrAttrName()).template dyn_cast_or_null() +// DEF: some-return-type AgetOp::getBAttr() { +// DEF-NEXT: auto attr = getBAttrAttr(); +// DEF-NEXT: if (!attr) +// DEF-NEXT: return some-const-builder-call(::mlir::Builder((*this)->getContext()), 4.2).some-convert-from-storage(); +// DEF-NEXT: return attr.some-convert-from-storage(); + +// DEF: some-attr-kind AgetOp::getCAttrAttr() +// DEF-NEXT: return (*this)->getAttr(getCAttrAttrName()).template dyn_cast_or_null() +// DEF: ::llvm::Optional AgetOp::getCAttr() { +// DEF-NEXT: auto attr = getCAttrAttr() +// DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); + +// Test setter methods +// --- + +// DEF: void AgetOp::setAAttrAttr(some-attr-kind attr) { +// DEF-NEXT: (*this)->setAttr(getAAttrAttrName(), attr); +// DEF: void AgetOp::setBAttrAttr(some-attr-kind attr) { +// DEF-NEXT: (*this)->setAttr(getBAttrAttrName(), attr); +// DEF: void AgetOp::setCAttrAttr(some-attr-kind attr) { +// DEF-NEXT: (*this)->setAttr(getCAttrAttrName(), attr); + +// Test remove methods +// --- + +// DEF: ::mlir::Attribute AgetOp::removeCAttrAttr() { +// DEF-NEXT: return (*this)->removeAttr(cAttrAttrName()); + +// Test build methods +// --- + +// DEF: void AgetOp::build( +// DEF: odsState.addAttribute(aAttrAttrName(odsState.name), aAttr); +// DEF: odsState.addAttribute(bAttrAttrName(odsState.name), bAttr); +// DEF: if (cAttr) { +// DEF-NEXT: odsState.addAttribute(cAttrAttrName(odsState.name), cAttr); + +// DEF: void AgetOp::build( +// DEF: some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr +// DEF: odsState.addAttribute(aAttrAttrName(odsState.name), some-const-builder-call(odsBuilder, aAttr)); + +// DEF: void AgetOp::build( +// DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes +// DEF: odsState.addAttributes(attributes); + def SomeTypeAttr : TypeAttrBase<"SomeType", "some type attribute">; def BOp : NS_Op<"b_op", []> { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -25,6 +25,7 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -524,6 +525,69 @@ void OpEmitter::emitDef(raw_ostream &os) { opClass.writeDefTo(os); } +// Helper to return the names for accessor. +template +static SmallVector getNames(const Operator &op, + StringRef name) { + auto prefixType = op.getDialect().getEmitAccessorPrefix(); + std::string prefix; + if (prefixType != Dialect::EmitPrefix::Raw) + prefix = read ? "get" : "set"; + + SmallVector names; + bool rawToo = prefixType == Dialect::EmitPrefix::Both; + + auto skip = [&](StringRef newName) { + bool shouldSkip = newName == "getOperands"; + if (!shouldSkip) + return false; + + // This note could be avoided where the final function generated would + // have been identical. But preferably in the op definition avoiding using + // the generic name and then getting a more specialize type is better. + PrintNote(op.getLoc(), + "Skipping generation of prefixed accessor `" + newName + + "` as overlaps with default one; generating raw form (`" + + name + "`) still"); + return true; + }; + + if (!prefix.empty()) { + names.push_back(prefix + convertToCamelFromSnakeCase(name, true)); + // Skip cases which would overlap with default ones for now. + if (skip(names.back())) { + rawToo = true; + names.clear(); + } else { + LLVM_DEBUG(llvm::errs() << "WITH_GETTER(\"" << op.getQualCppClassName() + << "::" << names.back() << "\");\n";); + } + } + + if (prefix.empty() || rawToo) + names.push_back(name.str()); + return names; +} +static SmallVector getGetNames(const Operator &op, + StringRef name) { + return getNames(op, name); +} +static SmallVector getSetNames(const Operator &op, + StringRef name) { + return getNames(op, name); +} + +static void errorIfPruned(size_t line, OpMethod *m, const Twine &methodName, + const Operator &op) { + if (m) + return; + PrintFatalError(op.getLoc(), "Unexpected overlap when generating `" + + methodName + "` for " + + op.getOperationName() + " (from line " + + Twine(line) + ")"); +} +#define ERROR_IF_PRUNED(M, N, O) errorIfPruned(__LINE__, M, N, O) + void OpEmitter::genAttrNameGetters() { // Enumerate the attribute names of this op, assigning each a relative // ordering. @@ -544,6 +608,7 @@ auto *method = opClass.addMethodAndPrune( "::llvm::ArrayRef<::llvm::StringRef>", "getAttributeNames", OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Inline)); + ERROR_IF_PRUNED(method, "getAttributeNames", op); auto &body = method->body(); if (attributeNames.empty()) { body << " return {};"; @@ -566,6 +631,7 @@ "::mlir::Identifier", "getAttributeNameForIndex", OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline), "unsigned", "index"); + ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); method->body() << " return getAttributeNameForIndex((*this)->getName(), index);"; } @@ -575,6 +641,7 @@ OpMethod::Property(OpMethod::MP_Private | OpMethod::MP_Inline | OpMethod::MP_Static), "::mlir::OperationName name, unsigned index"); + ERROR_IF_PRUNED(method, "getAttributeNameForIndex", op); method->body() << "assert(index < " << attributeNames.size() << " && \"invalid attribute index\");\n" " return name.getAbstractOperation()" @@ -585,25 +652,30 @@ // users. const char *attrNameMethodBody = " return getAttributeNameForIndex({0});"; for (const std::pair &attrIt : attributeNames) { - std::string methodName = (attrIt.first + "AttrName").str(); + for (StringRef name : getGetNames(op, attrIt.first)) { + std::string methodName = (name + "AttrName").str(); - // Generate the non-static variant. - { - auto *method = - opClass.addMethodAndPrune("::mlir::Identifier", methodName, - OpMethod::Property(OpMethod::MP_Inline)); - method->body() << llvm::formatv(attrNameMethodBody, attrIt.second).str(); - } + // Generate the non-static variant. + { + auto *method = + opClass.addMethodAndPrune("::mlir::Identifier", methodName, + OpMethod::Property(OpMethod::MP_Inline)); + ERROR_IF_PRUNED(method, methodName, op); + method->body() + << llvm::formatv(attrNameMethodBody, attrIt.second).str(); + } - // Generate the static variant. - { - auto *method = opClass.addMethodAndPrune( - "::mlir::Identifier", methodName, - OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static), - "::mlir::OperationName", "name"); - method->body() << llvm::formatv(attrNameMethodBody, - "name, " + Twine(attrIt.second)) - .str(); + // Generate the static variant. + { + auto *method = opClass.addMethodAndPrune( + "::mlir::Identifier", methodName, + OpMethod::Property(OpMethod::MP_Inline | OpMethod::MP_Static), + "::mlir::OperationName", "name"); + ERROR_IF_PRUNED(method, methodName, op); + method->body() << llvm::formatv(attrNameMethodBody, + "name, " + Twine(attrIt.second)) + .str(); + } } } } @@ -621,6 +693,7 @@ // Emit with return type specified. auto emitAttrWithReturnType = [&](StringRef name, Attribute attr) { auto *method = opClass.addMethodAndPrune(attr.getReturnType(), name); + ERROR_IF_PRUNED(method, name, op); auto &body = method->body(); body << " auto attr = " << name << "Attr();\n"; if (attr.hasDefaultValue()) { @@ -639,9 +712,9 @@ << ";\n"; }; - // Generate raw named accessor type. This is a wrapper class that allows - // referring to the attributes via accessors instead of having to use - // the string interface for better compile time verification. + // Generate named accessor with Attribute return type. This is a wrapper class + // that allows referring to the attributes via accessors instead of having to + // use the string interface for better compile time verification. auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { auto *method = opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); @@ -657,11 +730,13 @@ }; for (const NamedAttribute &namedAttr : op.getAttributes()) { - if (namedAttr.attr.isDerivedAttr()) { - emitDerivedAttr(namedAttr.name, namedAttr.attr); - } else { - emitAttrWithStorageType(namedAttr.name, namedAttr.attr); - emitAttrWithReturnType(namedAttr.name, namedAttr.attr); + for (StringRef name : getGetNames(op, namedAttr.name)) { + if (namedAttr.attr.isDerivedAttr()) { + emitDerivedAttr(name, namedAttr.attr); + } else { + emitAttrWithStorageType(name, namedAttr.attr); + emitAttrWithReturnType(name, namedAttr.attr); + } } } @@ -678,6 +753,7 @@ auto *method = opClass.addMethodAndPrune("bool", "isDerivedAttribute", OpMethod::MP_Static, "::llvm::StringRef", "name"); + ERROR_IF_PRUNED(method, "isDerivedAttribute", op); auto &body = method->body(); for (auto namedAttr : derivedAttrs) body << " if (name == \"" << namedAttr.name << "\") return true;\n"; @@ -687,6 +763,7 @@ { auto *method = opClass.addMethodAndPrune("::mlir::DictionaryAttr", "materializeDerivedAttributes"); + ERROR_IF_PRUNED(method, "materializeDerivedAttributes", op); auto &body = method->body(); auto nonMaterializable = @@ -734,16 +811,22 @@ // Generate raw named setter type. This is a wrapper class that allows setting // to the attributes via setters instead of having to use the string interface // for better compile time verification. - auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) { - auto *method = opClass.addMethodAndPrune("void", (name + "Attr").str(), + auto emitAttrWithStorageType = [&](StringRef setName, StringRef getName, + Attribute attr) { + auto *method = opClass.addMethodAndPrune("void", (setName + "Attr").str(), attr.getStorageType(), "attr"); if (method) - method->body() << " (*this)->setAttr(" << name << "AttrName(), attr);"; + method->body() << " (*this)->setAttr(" << getName + << "AttrName(), attr);"; }; - for (const NamedAttribute &namedAttr : op.getAttributes()) + for (const NamedAttribute &namedAttr : op.getAttributes()) { if (!namedAttr.attr.isDerivedAttr()) - emitAttrWithStorageType(namedAttr.name, namedAttr.attr); + for (auto names : llvm::zip(getSetNames(op, namedAttr.name), + getGetNames(op, namedAttr.name))) + emitAttrWithStorageType(std::get<0>(names), std::get<1>(names), + namedAttr.attr); + } } void OpEmitter::genOptionalAttrRemovers() { @@ -846,6 +929,7 @@ auto *m = opClass.addMethodAndPrune(rangeType, "getODSOperands", "unsigned", "index"); + ERROR_IF_PRUNED(m, "getODSOperands", op); auto &body = m->body(); body << formatv(valueRangeReturnCode, rangeBeginCall, "getODSOperandIndexAndLength(index)"); @@ -856,31 +940,38 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - if (operand.isOptional()) { - m = opClass.addMethodAndPrune("::mlir::Value", operand.name); - m->body() - << " auto operands = getODSOperands(" << i << ");\n" - << " return operands.empty() ? ::mlir::Value() : *operands.begin();"; - } else if (operand.isVariadicOfVariadic()) { - StringRef segmentAttr = - operand.constraint.getVariadicOfVariadicSegmentSizeAttr(); - if (isAdaptor) { - m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>", - operand.name); - m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, - segmentAttr, i); - continue; - } + for (StringRef name : getGetNames(op, operand.name)) { + if (operand.isOptional()) { + m = opClass.addMethodAndPrune("::mlir::Value", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " auto operands = getODSOperands(" << i << ");\n" + << " return operands.empty() ? ::mlir::Value() : " + "*operands.begin();"; + } else if (operand.isVariadicOfVariadic()) { + StringRef segmentAttr = + operand.constraint.getVariadicOfVariadicSegmentSizeAttr(); + if (isAdaptor) { + m = opClass.addMethodAndPrune( + "::llvm::SmallVector<::mlir::ValueRange>", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, + segmentAttr, i); + continue; + } - m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name); - m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr - << "Attr());"; - } else if (operand.isVariadic()) { - m = opClass.addMethodAndPrune(rangeType, operand.name); - m->body() << " return getODSOperands(" << i << ");"; - } else { - m = opClass.addMethodAndPrune("::mlir::Value", operand.name); - m->body() << " return *getODSOperands(" << i << ").begin();"; + m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSOperands(" << i << ").split(" + << segmentAttr << "Attr());"; + } else if (operand.isVariadic()) { + m = opClass.addMethodAndPrune(rangeType, name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSOperands(" << i << ");"; + } else { + m = opClass.addMethodAndPrune("::mlir::Value", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return *getODSOperands(" << i << ").begin();"; + } } } } @@ -912,31 +1003,35 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic() - ? "::mlir::MutableOperandRangeRange" - : "::mlir::MutableOperandRange", - (operand.name + "Mutable").str()); - auto &body = m->body(); - body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" - << " auto mutableRange = ::mlir::MutableOperandRange(getOperation(), " - "range.first, range.second"; - if (attrSizedOperands) - body << ", ::mlir::MutableOperandRange::OperandSegment(" << i - << "u, *getOperation()->getAttrDictionary().getNamed(" - "operand_segment_sizesAttrName()))"; - body << ");\n"; - - // If this operand is a nested variadic, we split the range into a - // MutableOperandRangeRange that provides a range over all of the - // sub-ranges. - if (operand.isVariadicOfVariadic()) { - body << " return " - "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" - << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() - << "AttrName()));\n"; - } else { - // Otherwise, we use the full range directly. - body << " return mutableRange;\n"; + for (StringRef name : getGetNames(op, operand.name)) { + auto *m = opClass.addMethodAndPrune( + operand.isVariadicOfVariadic() ? "::mlir::MutableOperandRangeRange" + : "::mlir::MutableOperandRange", + (name + "Mutable").str()); + ERROR_IF_PRUNED(m, name, op); + auto &body = m->body(); + body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " auto mutableRange = " + "::mlir::MutableOperandRange(getOperation(), " + "range.first, range.second"; + if (attrSizedOperands) + body << ", ::mlir::MutableOperandRange::OperandSegment(" << i + << "u, *getOperation()->getAttrDictionary().getNamed(" + "operand_segment_sizesAttrName()))"; + body << ");\n"; + + // If this operand is a nested variadic, we split the range into a + // MutableOperandRangeRange that provides a range over all of the + // sub-ranges. + if (operand.isVariadicOfVariadic()) { + body << " return " + "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName()));\n"; + } else { + // Otherwise, we use the full range directly. + body << " return mutableRange;\n"; + } } } } @@ -985,6 +1080,7 @@ auto *m = opClass.addMethodAndPrune("::mlir::Operation::result_range", "getODSResults", "unsigned", "index"); + ERROR_IF_PRUNED(m, "getODSResults", op); m->body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", "getODSResultIndexAndLength(index)"); @@ -992,18 +1088,22 @@ const auto &result = op.getResult(i); if (result.name.empty()) continue; - if (result.isOptional()) { - m = opClass.addMethodAndPrune("::mlir::Value", result.name); - m->body() - << " auto results = getODSResults(" << i << ");\n" - << " return results.empty() ? ::mlir::Value() : *results.begin();"; - } else if (result.isVariadic()) { - m = opClass.addMethodAndPrune("::mlir::Operation::result_range", - result.name); - m->body() << " return getODSResults(" << i << ");"; - } else { - m = opClass.addMethodAndPrune("::mlir::Value", result.name); - m->body() << " return *getODSResults(" << i << ").begin();"; + for (StringRef name : getGetNames(op, result.name)) { + if (result.isOptional()) { + m = opClass.addMethodAndPrune("::mlir::Value", name); + ERROR_IF_PRUNED(m, name, op); + m->body() + << " auto results = getODSResults(" << i << ");\n" + << " return results.empty() ? ::mlir::Value() : *results.begin();"; + } else if (result.isVariadic()) { + m = opClass.addMethodAndPrune("::mlir::Operation::result_range", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return getODSResults(" << i << ");"; + } else { + m = opClass.addMethodAndPrune("::mlir::Value", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << " return *getODSResults(" << i << ").begin();"; + } } } } @@ -1015,17 +1115,21 @@ if (region.name.empty()) continue; - // Generate the accessors for a variadic region. - if (region.isVariadic()) { - auto *m = opClass.addMethodAndPrune( - "::mlir::MutableArrayRef<::mlir::Region>", region.name); - m->body() << formatv(" return (*this)->getRegions().drop_front({0});", - i); - continue; - } + for (StringRef name : getGetNames(op, region.name)) { + // Generate the accessors for a variadic region. + if (region.isVariadic()) { + auto *m = opClass.addMethodAndPrune( + "::mlir::MutableArrayRef<::mlir::Region>", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv(" return (*this)->getRegions().drop_front({0});", + i); + continue; + } - auto *m = opClass.addMethodAndPrune("::mlir::Region &", region.name); - m->body() << formatv(" return (*this)->getRegion({0});", i); + auto *m = opClass.addMethodAndPrune("::mlir::Region &", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv(" return (*this)->getRegion({0});", i); + } } } @@ -1036,19 +1140,22 @@ if (successor.name.empty()) continue; - // Generate the accessors for a variadic successor list. - if (successor.isVariadic()) { - auto *m = - opClass.addMethodAndPrune("::mlir::SuccessorRange", successor.name); - m->body() << formatv( - " return {std::next((*this)->successor_begin(), {0}), " - "(*this)->successor_end()};", - i); - continue; - } + for (StringRef name : getGetNames(op, successor.name)) { + // Generate the accessors for a variadic successor list. + if (successor.isVariadic()) { + auto *m = opClass.addMethodAndPrune("::mlir::SuccessorRange", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv( + " return {std::next((*this)->successor_begin(), {0}), " + "(*this)->successor_end()};", + i); + continue; + } - auto *m = opClass.addMethodAndPrune("::mlir::Block *", successor.name); - m->body() << formatv(" return (*this)->getSuccessor({0});", i); + auto *m = opClass.addMethodAndPrune("::mlir::Block *", name); + ERROR_IF_PRUNED(m, name, op); + m->body() << formatv(" return (*this)->getSuccessor({0});", i); + } } } @@ -1315,8 +1422,8 @@ std::string resultType; const auto &namedAttr = op.getAttribute(0); - body << " auto attrName = " << namedAttr.name << "AttrName(" - << builderOpState + auto names = getGetNames(op, namedAttr.name); + body << " auto attrName = " << names.front() << "AttrName(" << builderOpState << ".name);\n" " for (auto attr : attributes) {\n" " if (attr.first != attrName) continue;\n"; @@ -1379,6 +1486,8 @@ body ? OpMethod::MP_Static : OpMethod::MP_StaticDeclaration; auto *method = opClass.addMethodAndPrune("void", "build", properties, paramStr); + if (body) + ERROR_IF_PRUNED(method, "build", op); FmtContext fctx; fctx.withBuilder(odsBuilder); @@ -1736,9 +1845,10 @@ SmallVector paramList; paramList.emplace_back(op.getCppClassName(), "op"); paramList.emplace_back("::mlir::PatternRewriter &", "rewriter"); - opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize", - OpMethod::MP_StaticDeclaration, - std::move(paramList)); + auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "canonicalize", + OpMethod::MP_StaticDeclaration, + std::move(paramList)); + ERROR_IF_PRUNED(m, "canonicalize", op); } // We get a prototype for 'getCanonicalizationPatterns' if requested directly @@ -1761,8 +1871,10 @@ "void", "getCanonicalizationPatterns", kind, std::move(paramList)); // If synthesizing the method, fill it it. - if (hasBody) + if (hasBody) { + ERROR_IF_PRUNED(method, "getCanonicalizationPatterns", op); method->body() << " results.add(canonicalize);\n"; + } } void OpEmitter::genFolderDecls() { @@ -1771,16 +1883,19 @@ if (def.getValueAsBit("hasFolder")) { if (hasSingleResult) { - opClass.addMethodAndPrune( + auto *m = opClass.addMethodAndPrune( "::mlir::OpFoldResult", "fold", OpMethod::MP_Declaration, "::llvm::ArrayRef<::mlir::Attribute>", "operands"); + ERROR_IF_PRUNED(m, "operands", op); } else { SmallVector paramList; paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", "results"); - opClass.addMethodAndPrune("::mlir::LogicalResult", "fold", - OpMethod::MP_Declaration, std::move(paramList)); + auto *m = opClass.addMethodAndPrune("::mlir::LogicalResult", "fold", + OpMethod::MP_Declaration, + std::move(paramList)); + ERROR_IF_PRUNED(m, "fold", op); } } } @@ -1803,7 +1918,9 @@ if (method.getDefaultImplementation() && !alwaysDeclaredMethods.count(method.getName())) continue; - genOpInterfaceMethod(method); + // Interface methods are allowed to overlap with existing methods, so don't + // check if pruned. + (void)genOpInterfaceMethod(method); } } @@ -1895,6 +2012,7 @@ .str(); auto *getEffects = opClass.addMethodAndPrune("void", "getEffects", type, "effects"); + ERROR_IF_PRUNED(getEffects, "getEffects", op); auto &body = getEffects->body(); // Add effect instances for each of the locations marked on the operation. @@ -1944,6 +2062,7 @@ assert(0 && "unable to find inferReturnTypes interface method"); return nullptr; }(); + ERROR_IF_PRUNED(method, "inferReturnTypes", op); auto &body = method->body(); body << " inferredReturnTypes.resize(" << op.getNumResults() << ");\n"; @@ -1989,6 +2108,7 @@ auto *method = opClass.addMethodAndPrune("::mlir::ParseResult", "parse", OpMethod::MP_Static, std::move(paramList)); + ERROR_IF_PRUNED(method, "parse", op); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); @@ -2007,6 +2127,7 @@ auto *method = opClass.addMethodAndPrune("void", "print", "::mlir::OpAsmPrinter &", "p"); + ERROR_IF_PRUNED(method, "print", op); FmtContext fctx; fctx.addSubst("cppClass", opClass.getClassName()); auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r"); @@ -2015,6 +2136,7 @@ void OpEmitter::genVerifier() { auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify"); + ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); body << " if (::mlir::failed(" << op.getAdaptorName() << "(*this).verify((*this)->getLoc()))) " @@ -2274,6 +2396,7 @@ auto *method = opClass.addMethodAndPrune( "::llvm::StringLiteral", "getOperationName", OpMethod::Property(OpMethod::MP_Static | OpMethod::MP_Constexpr)); + ERROR_IF_PRUNED(method, "getOperationName", op); method->body() << " return ::llvm::StringLiteral(\"" << op.getOperationName() << "\");"; } @@ -2301,6 +2424,7 @@ // Generate the right accessor for the number of results. auto *method = opClass.addMethodAndPrune( "void", "getAsmResultNames", "::mlir::OpAsmSetValueNameFn", "setNameFn"); + ERROR_IF_PRUNED(method, "getAsmResultNames", op); auto &body = method->body(); for (int i = 0; i != numResults; ++i) { body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" @@ -2365,6 +2489,7 @@ { auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands"); + ERROR_IF_PRUNED(m, "getOperands", op); m->body() << " return odsOperands;"; } std::string sizeAttrInit = @@ -2380,7 +2505,9 @@ fctx.withBuilder("::mlir::Builder(odsAttrs.getContext())"); auto emitAttr = [&](StringRef name, Attribute attr) { - auto &body = adaptor.addMethodAndPrune(attr.getStorageType(), name)->body(); + auto *method = adaptor.addMethodAndPrune(attr.getStorageType(), name); + ERROR_IF_PRUNED(method, "Adaptor::" + name, op); + auto &body = method->body(); body << " assert(odsAttrs && \"no attributes when constructing adapter\");" << "\n " << attr.getStorageType() << " attr = " << "odsAttrs.get(\"" << name << "\")."; @@ -2404,6 +2531,7 @@ { auto *m = adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); + ERROR_IF_PRUNED(m, "Adaptor::getAttributes", op); m->body() << " return odsAttrs;"; } for (auto &namedAttr : op.getAttributes()) { @@ -2416,6 +2544,7 @@ unsigned numRegions = op.getNumRegions(); if (numRegions > 0) { auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions"); + ERROR_IF_PRUNED(m, "Adaptor::getRegions", op); m->body() << " return odsRegions;"; } for (unsigned i = 0; i < numRegions; ++i) { @@ -2426,11 +2555,13 @@ // Generate the accessors for a variadic region. if (region.isVariadic()) { auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name); + ERROR_IF_PRUNED(m, "Adaptor::" + region.name, op); m->body() << formatv(" return odsRegions.drop_front({0});", i); continue; } auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name); + ERROR_IF_PRUNED(m, "Adaptor::" + region.name, op); m->body() << formatv(" return *odsRegions[{0}];", i); } @@ -2441,6 +2572,7 @@ void OpOperandAdaptorEmitter::addVerification() { auto *method = adaptor.addMethodAndPrune("::mlir::LogicalResult", "verify", "::mlir::Location", "loc"); + ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); const char *checkAttrSizedValueSegmentsCode = R"(