diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -80,13 +80,13 @@ // --- // DEF: some-attr-kind AOp::aAttrAttr() -// DEF-NEXT: (*this)->getAttr(aAttrAttrName()).template cast() +// DEF-NEXT: (*this)->getAttr(aAttrAttrName()).cast() // DEF: some-return-type AOp::aAttr() { // DEF-NEXT: auto attr = aAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::bAttrAttr() -// DEF-NEXT: return (*this)->getAttr(bAttrAttrName()).template dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr(bAttrAttrName()).dyn_cast_or_null() // DEF: some-return-type AOp::bAttr() { // DEF-NEXT: auto attr = bAttrAttr(); // DEF-NEXT: if (!attr) @@ -94,7 +94,7 @@ // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AOp::cAttrAttr() -// DEF-NEXT: return (*this)->getAttr(cAttrAttrName()).template dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr(cAttrAttrName()).dyn_cast_or_null() // DEF: ::llvm::Optional AOp::cAttr() { // DEF-NEXT: auto attr = cAttrAttr() // DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); @@ -194,13 +194,13 @@ // --- // DEF: some-attr-kind AgetOp::getAAttrAttr() -// DEF-NEXT: (*this)->getAttr(getAAttrAttrName()).template cast() +// DEF-NEXT: (*this)->getAttr(getAAttrAttrName()).cast() // DEF: some-return-type AgetOp::getAAttr() { // DEF-NEXT: auto attr = getAAttrAttr() // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AgetOp::getBAttrAttr() -// DEF-NEXT: return (*this)->getAttr(getBAttrAttrName()).template dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr(getBAttrAttrName()).dyn_cast_or_null() // DEF: some-return-type AgetOp::getBAttr() { // DEF-NEXT: auto attr = getBAttrAttr(); // DEF-NEXT: if (!attr) @@ -208,7 +208,7 @@ // DEF-NEXT: return attr.some-convert-from-storage(); // DEF: some-attr-kind AgetOp::getCAttrAttr() -// DEF-NEXT: return (*this)->getAttr(getCAttrAttrName()).template dyn_cast_or_null() +// DEF-NEXT: return (*this)->getAttr(getCAttrAttrName()).dyn_cast_or_null() // DEF: ::llvm::Optional AgetOp::getCAttr() { // DEF-NEXT: auto attr = getCAttrAttr() // DEF-NEXT: return attr ? ::llvm::Optional(attr.some-convert-from-storage()) : (::llvm::None); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -42,13 +42,7 @@ static const char *const odsBuilder = "odsBuilder"; static const char *const builderOpState = "odsState"; -// Code for OpAdaptors to lookup an attribute using strings on the provided -// DictionaryAttr. -// -// {0}: The attribute name. -static const char *const adaptorGetAttr = "odsAttrs.get(\"{0}\")"; - -// Code for Ops to lookup an attribute using the cached identifier. +// Code for an Op to lookup an attribute. Uses cached identifiers. // // {0}: The attribute's getter name. static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())"; @@ -63,7 +57,7 @@ // {2}: The total number of variadic operands/results. // {3}: The total number of actual values. // {4}: "operand" or "result". -const char *sameVariadicSizeValueRangeCalcCode = R"( +static const char *const sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; for (unsigned i = 0; i < index; ++i) @@ -87,14 +81,15 @@ // (variadic or not). // // {0}: The name of the attribute specifying the segment sizes. -const char *adapterSegmentSizeAttrInitCode = R"( +static const char *const adapterSegmentSizeAttrInitCode = R"( assert(odsAttrs && "missing segment size attribute for op"); auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>(); )"; -const char *opSegmentSizeAttrInitCode = R"( - auto sizeAttr = (*this)->getAttr({0}).cast<::mlir::DenseIntElementsAttr>(); +static const char *const opSegmentSizeAttrInitCode = R"( + auto sizeAttr = + (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>(); )"; -const char *attrSizedSegmentValueRangeCalcCode = R"( +static const char *const attrSizedSegmentValueRangeCalcCode = R"( const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin(); if (sizeAttr.isSplat()) return {*sizeAttrValueIt * index, *sizeAttrValueIt}; @@ -109,7 +104,7 @@ // // {0}: The name of the segment attribute. // {1}: The index of the main operand. -const char *variadicOfVariadicAdaptorCalcCode = R"( +static const char *const variadicOfVariadicAdaptorCalcCode = R"( auto tblgenTmpOperands = getODSOperands({1}); auto sizeAttrValues = {0}().getValues(); auto sizeAttrIt = sizeAttrValues.begin(); @@ -126,17 +121,17 @@ // // {0}: The begin iterator of the actual values. // {1}: The call to generate the start and length of the value range. -const char *valueRangeReturnCode = R"( +static const char *const valueRangeReturnCode = R"( auto valueRange = {1}; return {{std::next({0}, valueRange.first), std::next({0}, valueRange.first + valueRange.second)}; )"; -const char *typeVerifierSignature = +static const char *const typeVerifierSignature = "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type " "type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)"; -const char *typeVerifierErrorHandler = +static const char *const typeVerifierErrorHandler = " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must " "be {0}, but got \" << type"; @@ -166,7 +161,7 @@ // via getValueAsString. static inline bool hasStringAttribute(const Record &record, StringRef fieldName) { - auto valueInit = record.getValueInit(fieldName); + auto *valueInit = record.getValueInit(fieldName); return isa(valueInit); } @@ -185,6 +180,90 @@ !attr.getConstBuilderTemplate().empty(); } +namespace { +/// Helper class to select between OpAdaptor and Op code templates. +class OpOrAdaptorHelper { +public: + OpOrAdaptorHelper(const Operator &op, bool emitForOp) + : op(op), emitForOp(emitForOp) {} + + /// Object that wraps a functor in a stream operator for interop with + /// llvm::formatv. + class Formatter { + public: + template + Formatter(Functor &&func) : func(std::forward(func)) {} + + std::string str() const { + std::string result; + llvm::raw_string_ostream os(result); + os << *this; + return os.str(); + } + + private: + std::function func; + + friend raw_ostream &operator<<(raw_ostream &os, const Formatter &fmt) { + return fmt.func(os); + } + }; + + // Generate code for getting an attribute. + Formatter getAttr(StringRef attrName) const { + return [this, attrName](raw_ostream &os) -> raw_ostream & { + if (!emitForOp) + return os << formatv("odsAttrs.get(\"{0}\")", attrName); + return os << formatv(opGetAttr, op.getGetterName(attrName)); + }; + } + + // Get the prefix code for emitting an error. + Formatter emitErrorPrefix() const { + return [this](raw_ostream &os) -> raw_ostream & { + if (emitForOp) + return os << "emitOpError("; + return os << formatv("emitError(loc, \"'{0}' op \"", + op.getOperationName()); + }; + } + + // Get the call to get an operand or segment of operands. + Formatter getOperand(unsigned index) const { + return [this, index](raw_ostream &os) -> raw_ostream & { + return os << formatv(op.getOperand(index).isVariadic() + ? "this->getODSOperands({0})" + : "(*this->getODSOperands({0}).begin())", + index); + }; + } + + // Get the call to get a result of segment of results. + Formatter getResult(unsigned index) const { + return [this, index](raw_ostream &os) -> raw_ostream & { + if (!emitForOp) + return os << ""; + return os << formatv(op.getResult(index).isVariadic() + ? "this->getODSResults({0})" + : "(*this->getODSResults({0}).begin())", + index); + }; + } + + // Return whether an op instance is available. + bool isEmittingForOp() const { return emitForOp; } + + // Return the ODS operation wrapper. + const Operator &getOp() const { return op; } + +private: + // The operation ODS wrapper. + const Operator &op; + // True if code is being generate for an op. False for an adaptor. + const bool emitForOp; +}; +} // end anonymous namespace + //===----------------------------------------------------------------------===// // Op emitter //===----------------------------------------------------------------------===// @@ -374,57 +453,31 @@ // Populate the format context `ctx` with substitutions of attributes, operands // and results. -// - attrGet corresponds to the name of the function to call to get value of -// attribute (the generated function call returns an Attribute); -// - operandGet corresponds to the name of the function with which to retrieve -// an operand (the generated function call returns an OperandRange); -// - resultGet corresponds to the name of the function to get an result (the -// generated function call returns a ValueRange); -// - opRequired whether an op instance is needed -static void populateSubstitutions(const Operator &op, const char *attrGet, - const char *operandGet, const char *resultGet, - FmtContext &ctx, bool opRequired) { - // Populate substitutions for attributes and named operands. - for (const auto &namedAttr : op.getAttributes()) { - ctx.addSubst(namedAttr.name, - formatv(attrGet, opRequired ? op.getGetterName(namedAttr.name) - : namedAttr.name)); - } +static void populateSubstitutions(const OpOrAdaptorHelper &emitHelper, + FmtContext &ctx) { + // Populate substitutions for attributes. + auto &op = emitHelper.getOp(); + for (const auto &namedAttr : op.getAttributes()) + ctx.addSubst(namedAttr.name, emitHelper.getAttr(namedAttr.name).str()); + + // Populate substitutions for named operands. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); - if (value.name.empty()) - continue; - - if (value.isVariadic()) - ctx.addSubst(value.name, formatv("{0}({1})", operandGet, i)); - else - ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", operandGet, i)); + if (!value.name.empty()) + ctx.addSubst(value.name, emitHelper.getOperand(i).str()); } // Populate substitutions for results. for (int i = 0, e = op.getNumResults(); i < e; ++i) { auto &value = op.getResult(i); - if (value.name.empty()) - continue; - - if (value.isVariadic()) - ctx.addSubst(value.name, formatv("{0}({1})", resultGet, i)); - else - ctx.addSubst(value.name, formatv("(*{0}({1}).begin())", resultGet, i)); + if (!value.name.empty()) + ctx.addSubst(value.name, emitHelper.getResult(i).str()); } } -// Generate attribute verification. If emitVerificationRequiringOp is set then -// only verification for attributes whose value depend on op being known are -// emitted, else only verification that doesn't depend on the op being known are -// generated. -// - emitErrorPrefix is the prefix for the error emitting call which consists -// of the entire function call up to start of error message fragment; -// - emitVerificationRequiringOp specifies whether verification should be -// emitted for verification that require the op to exist; -static void genAttributeVerifier(const Operator &op, const char *attrGet, - const Twine &emitErrorPrefix, - bool emitVerificationRequiringOp, +// Generate attribute verification. If an op instance is not available, then +// attribute checks that require one will not be emitted. +static void genAttributeVerifier(const OpOrAdaptorHelper &emitHelper, FmtContext &ctx, OpMethodBody &body) { // Check that a required attribute exists. // @@ -448,7 +501,7 @@ return {2}"attribute '{3}' failed to satisfy constraint: {4}"); )"; - for (const auto &namedAttr : op.getAttributes()) { + for (const auto &namedAttr : emitHelper.getOp().getAttributes()) { const auto &attr = namedAttr.attr; StringRef attrName = namedAttr.name; if (attr.isDerivedAttr()) @@ -460,7 +513,7 @@ // If the attribute's condition needs an op but none is available, then the // condition cannot be emitted. bool canEmitCondition = - !StringRef(condition).contains("$_op") || emitVerificationRequiringOp; + !StringRef(condition).contains("$_op") || emitHelper.isEmittingForOp(); // Prefix with `tblgen_` to avoid hiding the attribute accessor. Twine varName = tblgenNamePrefix + attrName; @@ -471,16 +524,17 @@ continue; body << formatv(" {\n auto {0} = {1};", varName, - formatv(attrGet, emitVerificationRequiringOp - ? op.getGetterName(attrName) - : attrName)); + emitHelper.getAttr(attrName)); - if (!allowMissingAttr) - body << formatv(checkRequiredAttr, varName, emitErrorPrefix, attrName); + if (!allowMissingAttr) { + body << formatv(checkRequiredAttr, varName, emitHelper.emitErrorPrefix(), + attrName); + } if (canEmitCondition) { body << formatv(checkAttrCondition, varName, - tgfmt(condition, &ctx.withSelf(varName)), emitErrorPrefix, - attrName, escapeString(attr.getSummary())); + tgfmt(condition, &ctx.withSelf(varName)), + emitHelper.emitErrorPrefix(), attrName, + escapeString(attr.getSummary())); } body << "}\n"; } @@ -685,13 +739,11 @@ opClass.addMethodAndPrune(attr.getStorageType(), (name + "Attr").str()); if (!method) return; - auto &body = method->body(); - body << " return (*this)->getAttr(" << name << "AttrName()).template "; - if (attr.isOptional() || attr.hasDefaultValue()) - body << "dyn_cast_or_null<"; - else - body << "cast<"; - body << attr.getStorageType() << ">();"; + method->body() << formatv( + " return {0}.{1}<{2}>();", formatv(opGetAttr, name), + attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null" + : "cast", + attr.getStorageType()); }; for (const NamedAttribute &namedAttr : op.getAttributes()) { @@ -783,8 +835,8 @@ auto *method = opClass.addMethodAndPrune( "void", (setterName + "Attr").str(), attr.getStorageType(), "attr"); if (method) - method->body() << " (*this)->setAttr(" << getterName - << "AttrName(), attr);"; + method->body() << formatv(" (*this)->setAttr({0}AttrName(), attr);", + getterName); }; for (const NamedAttribute &namedAttr : op.getAttributes()) { @@ -806,8 +858,8 @@ "::mlir::Attribute", ("remove" + upperInitial + suffix + "Attr").str()); if (!method) return; - method->body() << " return (*this)->removeAttr(" << op.getGetterName(name) - << "AttrName());"; + method->body() << formatv(" return (*this)->removeAttr({0}AttrName());", + op.getGetterName(name)); }; for (const NamedAttribute &namedAttr : op.getAttributes()) @@ -949,7 +1001,7 @@ // array. std::string attrSizeInitCode; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - std::string attr = op.getGetterName("operand_segment_sizes") + "AttrName()"; + std::string attr = op.getGetterName("operand_segment_sizes"); attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); } @@ -1037,7 +1089,7 @@ // Build the initializer string for the result segment size attribute. std::string attrSizeInitCode; if (attrSizedResults) { - std::string attr = op.getGetterName("result_segment_sizes") + "AttrName()"; + std::string attr = op.getGetterName("result_segment_sizes"); attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str(); } @@ -2037,18 +2089,16 @@ auto emitType = [&](const tblgen::Operator::ArgOrType &type) -> OpMethodBody & { - if (type.isArg()) { - auto argIndex = type.getArg(); - assert(!op.getArg(argIndex).is()); - auto arg = op.getArgToOperandOrAttribute(argIndex); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) - return body << "operands[" << arg.operandOrAttributeIndex() - << "].getType()"; - return body << "attributes[" << arg.operandOrAttributeIndex() - << "].getType()"; - } else { + if (!type.isArg()) return body << tgfmt(*type.getType().getBuilderCall(), &fctx); - } + auto argIndex = type.getArg(); + assert(!op.getArg(argIndex).is()); + auto arg = op.getArgToOperandOrAttribute(argIndex); + if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) + return body << "operands[" << arg.operandOrAttributeIndex() + << "].getType()"; + return body << "attributes[" << arg.operandOrAttributeIndex() + << "].getType()"; }; for (int i = 0, e = op.getNumResults(); i != e; ++i) { @@ -2085,7 +2135,7 @@ if (hasStringAttribute(def, "assemblyFormat")) return; - auto valueInit = def.getValueInit("printer"); + auto *valueInit = def.getValueInit("printer"); StringInit *stringInit = dyn_cast(valueInit); if (!stringInit) return; @@ -2100,10 +2150,8 @@ } /// Generate verification on native traits requiring attributes. -static void genNativeTraitAttrVerifier(OpMethodBody &body, const Operator &op, - const char *const attrGet, - const Twine &emitError, - bool opRequired) { +static void genNativeTraitAttrVerifier(OpMethodBody &body, + const OpOrAdaptorHelper &emitHelper) { // Check that the variadic segment sizes attribute exists and contains the // expected number of elements. // @@ -2127,6 +2175,7 @@ // Verify a few traits first so that we can use getODSOperands() and // getODSResults() in the rest of the verifier. + auto &op = emitHelper.getOp(); for (auto &trait : op.getTraits()) { auto *t = dyn_cast(&trait); if (!t) @@ -2134,18 +2183,15 @@ std::string traitName = t->getFullyQualifiedTraitName(); if (traitName == "::mlir::OpTrait::AttrSizedOperandSegments") { StringRef attrName = "operand_segment_sizes"; - body << formatv( - checkAttrSizedValueSegmentsCode, attrName, op.getNumOperands(), - "operand", - formatv(attrGet, opRequired ? op.getGetterName(attrName) : attrName), - emitError); + body << formatv(checkAttrSizedValueSegmentsCode, attrName, + op.getNumOperands(), "operand", + emitHelper.getAttr(attrName), + emitHelper.emitErrorPrefix()); } else if (traitName == "::mlir::OpTrait::AttrSizedResultSegments") { StringRef attrName = "result_segment_sizes"; body << formatv( checkAttrSizedValueSegmentsCode, attrName, op.getNumResults(), - "result", - formatv(attrGet, opRequired ? op.getGetterName(attrName) : attrName), - emitError); + "result", emitHelper.getAttr(attrName), emitHelper.emitErrorPrefix()); } } } @@ -2155,16 +2201,15 @@ ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); - genNativeTraitAttrVerifier(body, op, opGetAttr, "emitOpError(", true); + OpOrAdaptorHelper emitHelper(op, /*isOp=*/true); + genNativeTraitAttrVerifier(body, emitHelper); auto *valueInit = def.getValueInit("verifier"); StringInit *stringInit = dyn_cast(valueInit); bool hasCustomVerify = stringInit && !stringInit->getValue().empty(); - populateSubstitutions(op, opGetAttr, "this->getODSOperands", - "this->getODSResults", verifyCtx, /*opRequired=*/true); + populateSubstitutions(emitHelper, verifyCtx); - genAttributeVerifier(op, opGetAttr, "emitOpError(", - /*emitVerificationRequiringOp=*/true, verifyCtx, body); + genAttributeVerifier(emitHelper, verifyCtx, body); genOperandResultVerifier(body, op.getOperands(), "operand"); genOperandResultVerifier(body, op.getResults(), "result"); @@ -2400,9 +2445,9 @@ // Add the native and interface traits. for (const auto &trait : op.getTraits()) { - if (auto opTrait = dyn_cast(&trait)) + if (auto *opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); - else if (auto opTrait = dyn_cast(&trait)) + else if (auto *opTrait = dyn_cast(&trait)) opClass.addTrait(opTrait->getFullyQualifiedTraitName()); } } @@ -2472,7 +2517,7 @@ const Operator &op; Class adaptor; }; -} // end namespace +} // end anonymous namespace OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : op(op), adaptor(op.getAdaptorName()) { @@ -2594,17 +2639,12 @@ ERROR_IF_PRUNED(method, "verify", op); auto &body = method->body(); - std::string emitError = - "emitError(loc, \"'" + op.getOperationName() + "' op \""; - genNativeTraitAttrVerifier(body, op, adaptorGetAttr, emitError, - /*opRequired=*/false); - + OpOrAdaptorHelper emitHelper(op, /*isOp=*/false); + genNativeTraitAttrVerifier(body, emitHelper); FmtContext verifyCtx; - populateSubstitutions(op, adaptorGetAttr, "getODSOperands", - "", verifyCtx, - /*opRequired=*/false); - genAttributeVerifier(op, adaptorGetAttr, emitError, - /*emitVerificationRequiringOp*/ false, verifyCtx, body); + populateSubstitutions(emitHelper, verifyCtx); + + genAttributeVerifier(emitHelper, verifyCtx, body); body << " return ::mlir::success();"; }