diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -1884,6 +1884,35 @@
   });
 }
 
+/// Collect every variable reachable from `element`, looking through custom
+/// directive arguments, optional groups, `type` and `functional-type`
+/// directives, and nested oilist clauses. Other elements contribute nothing.
+static void collect(FormatElement *element,
+                    SmallVectorImpl<VariableElement *> &variables) {
+  TypeSwitch<FormatElement *>(element)
+      .Case([&](VariableElement *var) { variables.emplace_back(var); })
+      .Case([&](CustomDirective *ele) {
+        for (FormatElement *arg : ele->getArguments())
+          collect(arg, variables);
+      })
+      .Case([&](OptionalElement *ele) {
+        for (FormatElement *arg : ele->getThenElements())
+          collect(arg, variables);
+        for (FormatElement *arg : ele->getElseElements())
+          collect(arg, variables);
+      })
+      .Case([&](TypeDirective *ele) { collect(ele->getArg(), variables); })
+      .Case([&](FunctionalTypeDirective *funcType) {
+        collect(funcType->getInputs(), variables);
+        collect(funcType->getResults(), variables);
+      })
+      .Case([&](OIListElement *oilist) {
+        for (ArrayRef<FormatElement *> clause : oilist->getParsingElements())
+          for (FormatElement *clauseElement : clause)
+            collect(clauseElement, variables);
+      });
+}
+
 void OperationFormat::genElementPrinter(FormatElement *element,
                                         MethodBody &body, Operator &op,
                                         bool &shouldEmitSpace,
@@ -1950,8 +1979,36 @@
       LiteralElement *lelement = std::get<0>(clause);
       ArrayRef<FormatElement *> pelement = std::get<1>(clause);
 
+      // Print the clause if its keyword was recorded as a unit attribute, or
+      // if any variable that appears within the clause is present.
+      SmallVector<VariableElement *> vars;
+      for (FormatElement *el : pelement)
+        collect(el, vars);
       body << "  if ((*this)->hasAttrOfType<UnitAttr>(\""
-           << lelement->getSpelling() << "\")) {\n";
+           << lelement->getSpelling() << "\")";
+      for (VariableElement *var : vars) {
+        TypeSwitch<FormatElement *>(var)
+            .Case([&](AttributeVariable *attrEle) {
+              // True whenever the attribute is stored on the op, including
+              // when it holds its default value.
+              body << " || " << op.getGetterName(attrEle->getVar()->name)
+                   << "Attr()";
+            })
+            .Case<OperandVariable, ResultVariable>([&](auto *ele) {
+              // Variadic values are tested through the size of their range;
+              // optional ones through the getter's conversion to bool.
+              body << " || " << op.getGetterName(ele->getVar()->name) << "()";
+              if (ele->getVar()->isVariadic())
+                body << ".size()";
+            })
+            .Case([&](RegionVariable *reg) {
+              // Region getters return `Region &` (or a range of regions),
+              // neither of which converts to bool; test for emptiness.
+              body << " || !" << op.getGetterName(reg->getVar()->name)
+                   << "().empty()";
+            });
+      }
+      body << ") {\n";
       genLiteralPrinter(lelement->getSpelling(), body, shouldEmitSpace,
                         lastWasPunctuation);
       for (FormatElement *element : pelement) {
@@ -2877,49 +2934,60 @@
 LogicalResult
 OpFormatParser::verifyOIListParsingElement(FormatElement *element, SMLoc loc) {
-  return TypeSwitch<FormatElement *, LogicalResult>(element)
-      // Only optional attributes can be within an oilist parsing group.
-      .Case([&](AttributeVariable *attrEle) {
-        if (!attrEle->getVar()->attr.isOptional())
-          return emitError(loc, "only optional attributes can be used to "
-                                "in an oilist parsing group");
-        return success();
-      })
-      // Only optional-like(i.e. variadic) operands can be within an oilist
-      // parsing group.
-      .Case([&](OperandVariable *ele) {
-        if (!ele->getVar()->isVariableLength())
-          return emitError(loc, "only variable length operands can be "
-                                "used within an oilist parsing group");
-        return success();
-      })
-      // Only optional-like(i.e. variadic) results can be within an oilist
-      // parsing group.
-      .Case([&](ResultVariable *ele) {
-        if (!ele->getVar()->isVariableLength())
-          return emitError(loc, "only variable length results can be "
-                                "used within an oilist parsing group");
-        return success();
-      })
-      .Case([&](RegionVariable *) {
-        // TODO: When ODS has proper support for marking "optional" regions, add
-        // a check here.
-        return success();
-      })
-      .Case([&](TypeDirective *ele) {
-        return verifyOIListParsingElement(ele->getArg(), loc);
-      })
-      .Case([&](FunctionalTypeDirective *ele) {
-        if (failed(verifyOIListParsingElement(ele->getInputs(), loc)))
-          return failure();
-        return verifyOIListParsingElement(ele->getResults(), loc);
-      })
-      // Literals, whitespace, and custom directives may be used.
-      .Case<LiteralElement, WhitespaceElement, CustomDirective>(
-          [&](FormatElement *) { return success(); })
-      .Default([&](FormatElement *) {
-        return emitError(loc, "only literals, types, and variables can be "
-                              "used within an oilist group");
-      });
+  // Reject elements that `collect` would silently skip, e.g. `attr-dict`,
+  // `operands`, or `regions`: they carry no variables to verify below but
+  // are not valid inside an oilist clause.
+  if (!isa<LiteralElement, WhitespaceElement, VariableElement,
+           CustomDirective, OptionalElement, TypeDirective,
+           FunctionalTypeDirective>(element))
+    return emitError(loc, "only literals, types, and variables can be used "
+                          "within an oilist group");
+
+  // Every variable in the clause, including those nested inside custom
+  // directives, optional groups, and type directives, must be omittable.
+  SmallVector<VariableElement *> vars;
+  collect(element, vars);
+  for (VariableElement *var : vars) {
+    LogicalResult result =
+        TypeSwitch<FormatElement *, LogicalResult>(var)
+            // Only optional or default-valued attributes can be within an
+            // oilist parsing group.
+            .Case([&](AttributeVariable *attrEle) {
+              if (!attrEle->getVar()->attr.isOptional() &&
+                  !attrEle->getVar()->attr.hasDefaultValue())
+                return emitError(loc, "only optional attributes can be used in "
+                                      "an oilist parsing group");
+              return success();
+            })
+            // Only variable-length (optional or variadic) operands can be
+            // within an oilist parsing group.
+            .Case([&](OperandVariable *ele) {
+              if (!ele->getVar()->isVariableLength())
+                return emitError(loc, "only variable length operands can be "
+                                      "used within an oilist parsing group");
+              return success();
+            })
+            // Only variable-length (optional or variadic) results can be
+            // within an oilist parsing group.
+            .Case([&](ResultVariable *ele) {
+              if (!ele->getVar()->isVariableLength())
+                return emitError(loc, "only variable length results can be "
+                                      "used within an oilist parsing group");
+              return success();
+            })
+            .Case([&](RegionVariable *) {
+              // TODO: When ODS has proper support for marking "optional"
+              // regions, add a check here.
+              return success();
+            })
+            .Default([&](FormatElement *) {
+              return emitError(loc,
+                               "only literals, types, and variables can be "
+                               "used within an oilist group");
+            });
+    if (failed(result))
+      return failure();
+  }
+  return success();
 }
 
 FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,