diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -208,6 +208,13 @@
 def DirectiveRegionsInvalidC : TestFormat_Op<[{
   type(regions)
 }]>;
+// CHECK: error: format ambiguity caused by `attr-dict` directive followed by region `foo`
+// CHECK: note: try using `attr-dict-with-keyword` instead
+def DirectiveRegionsInvalidD : TestFormat_Op<[{
+  attr-dict $foo
+}]> {
+  let regions = (region AnyRegion:$foo);
+}
 
 //===----------------------------------------------------------------------===//
 // results
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -2213,6 +2213,9 @@
   /// Verify that attributes elements aren't followed by colon literals.
   LogicalResult verifyAttributeColonType(SMLoc loc,
                                          ArrayRef<FormatElement *> elements);
+  /// Verify that the attribute dictionary directive isn't followed by a region.
+  LogicalResult verifyAttrDictRegion(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements);
 
   /// Verify the state of operation operands within the format.
   LogicalResult
@@ -2349,6 +2352,11 @@
   // better to just error out here instead.
   if (failed(verifyAttributeColonType(loc, elements)))
     return failure();
+  // Check that there are no region variables following an attribute dictionary.
+  // Both start with `{` and so the optional attribute dictionary can cause
+  // format ambiguities.
+  if (failed(verifyAttrDictRegion(loc, elements)))
+    return failure();
 
   // Check for VariadicOfVariadic variables. The segment attribute of those
   // variables will be infered.
@@ -2380,48 +2388,46 @@
   return isa<WhitespaceElement, AttrDictDirective>(el);
 }
 
-/// Scan the given range of elements from the start for a colon literal,
-/// skipping any optionally-parsed elements. If an optional group is
-/// encountered, this function recurses into the 'then' and 'else' elements to
-/// check if they are invalid. Returns `success` if the range is known to be
-/// valid or `None` if scanning reached the end.
+/// Scan the given range of elements from the start for an invalid format
+/// element that satisfies `isInvalid`, skipping any optionally-parsed elements.
+/// If an optional group is encountered, this function recurses into the 'then'
+/// and 'else' elements to check if they are invalid. Returns `success` if the
+/// range is known to be valid or `None` if scanning reached the end.
 ///
 /// Since the guard element of an optional group is required, this function
 /// accepts an optional element pointer to mark it as required.
-static Optional<LogicalResult> checkElementRangeForColon(
-    function_ref<LogicalResult(const Twine &)> emitError, StringRef attrName,
+static Optional<LogicalResult> checkRangeForElement(
+    FormatElement *base,
+    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
     iterator_range<ArrayRef<FormatElement *>::iterator> elementRange,
     FormatElement *optionalGuard = nullptr) {
   for (FormatElement *element : elementRange) {
-    // Skip optionally parsed elements.
-    if (element != optionalGuard && isOptionallyParsed(element))
-      continue;
+    // If we encounter an invalid element, return an error.
+    if (isInvalid(base, element))
+      return failure();
 
     // Recurse on optional groups.
     if (auto *optional = dyn_cast<OptionalElement>(element)) {
-      if (Optional<LogicalResult> result = checkElementRangeForColon(
-              emitError, attrName, optional->getThenElements(),
+      if (Optional<LogicalResult> result = checkRangeForElement(
+              base, isInvalid, optional->getThenElements(),
               // The optional group guard is required for the group.
               optional->getThenElements().front()))
         if (failed(*result))
           return failure();
-      if (Optional<LogicalResult> result = checkElementRangeForColon(
-              emitError, attrName, optional->getElseElements()))
+      if (Optional<LogicalResult> result = checkRangeForElement(
+              base, isInvalid, optional->getElseElements()))
         if (failed(*result))
           return failure();
       // Skip the optional group.
       continue;
     }
 
-    // If we encounter anything other than `:`, this range is range.
-    auto *literal = dyn_cast<LiteralElement>(element);
-    if (!literal || literal->getSpelling() != ":")
-      return success();
-    // If we encounter `:`, the range is known to be invalid.
-    return emitError(
-        llvm::formatv("format ambiguity caused by `:` literal found after "
-                      "attribute `{0}` which does not have a buildable type",
-                      attrName));
+    // Skip optionally parsed elements.
+    if (element != optionalGuard && isOptionallyParsed(element))
+      continue;
+
+    // We found a closing element that is valid.
+    return success();
   }
   // Return None to indicate that we reached the end.
   return llvm::None;
@@ -2431,46 +2437,42 @@
 /// literal, resulting in an ambiguous assembly format. Returns a non-null
 /// attribute if verification of said attribute reached the end of the range.
 /// Returns null if all attribute elements are verified.
-static FailureOr<AttributeVariable *>
-verifyAttributeColon(function_ref<LogicalResult(const Twine &)> emitError,
-                     ArrayRef<FormatElement *> elements) {
+static FailureOr<FormatElement *> verifyAdjacentElements(
+    function_ref<bool(FormatElement *)> isBase,
+    function_ref<bool(FormatElement *, FormatElement *)> isInvalid,
+    ArrayRef<FormatElement *> elements) {
   for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) {
-    // The current attribute being verified.
-    AttributeVariable *attr = nullptr;
-
-    if ((attr = dyn_cast<AttributeVariable>(*it))) {
-      // Check only attributes without type builders or that are known to call
-      // the generic attribute parser.
-      if (attr->getTypeBuilder() ||
-          !(attr->shouldBeQualified() ||
-            attr->getVar()->attr.getStorageType() == "::mlir::Attribute"))
-        continue;
+    // The current base element being verified.
+    FormatElement *base;
+
+    if (isBase(*it)) {
+      base = *it;
     } else if (auto *optional = dyn_cast<OptionalElement>(*it)) {
       // Recurse on optional groups.
-      FailureOr<AttributeVariable *> thenResult =
-          verifyAttributeColon(emitError, optional->getThenElements());
+      FailureOr<FormatElement *> thenResult = verifyAdjacentElements(
+          isBase, isInvalid, optional->getThenElements());
       if (failed(thenResult))
         return failure();
-      FailureOr<AttributeVariable *> elseResult =
-          verifyAttributeColon(emitError, optional->getElseElements());
+      FailureOr<FormatElement *> elseResult = verifyAdjacentElements(
+          isBase, isInvalid, optional->getElseElements());
       if (failed(elseResult))
         return failure();
-      // If either optional group has an unverified attribute, save it.
+      // If either optional group has an unverified base element, save it.
       // Otherwise, move on to the next element.
-      if (!(attr = *thenResult) && !(attr = *elseResult))
+      if (!(base = *thenResult) && !(base = *elseResult))
         continue;
     } else {
       continue;
     }
 
     // Verify subsequent elements for potential ambiguities.
-    if (Optional<LogicalResult> result = checkElementRangeForColon(
-            emitError, attr->getVar()->name, {std::next(it), e})) {
+    if (Optional<LogicalResult> result =
+            checkRangeForElement(base, isInvalid, {std::next(it), e})) {
       if (failed(*result))
         return failure();
     } else {
-      // Since we reached the end, return the attribute as unverified.
-      return attr;
+      // Since we reached the end, return the base element as unverified.
+      return base;
     }
   }
   // All attribute elements are known to be verified.
@@ -2480,8 +2482,52 @@
 LogicalResult
 OpFormatParser::verifyAttributeColonType(SMLoc loc,
                                          ArrayRef<FormatElement *> elements) {
-  return verifyAttributeColon(
-      [&](const Twine &msg) { return emitError(loc, msg); }, elements);
+  auto isBase = [](FormatElement *el) {
+    auto *attr = dyn_cast<AttributeVariable>(el);
+    if (!attr)
+      return false;
+    // Check only attributes without type builders or that are known to call
+    // the generic attribute parser.
+    return !attr->getTypeBuilder() &&
+           (attr->shouldBeQualified() ||
+            attr->getVar()->attr.getStorageType() == "::mlir::Attribute");
+  };
+  auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+    auto *literal = dyn_cast<LiteralElement>(el);
+    if (!literal || literal->getSpelling() != ":")
+      return false;
+    // If we encounter `:`, the range is known to be invalid.
+    (void)emitError(
+        loc,
+        llvm::formatv("format ambiguity caused by `:` literal found after "
+                      "attribute `{0}` which does not have a buildable type",
+                      cast<AttributeVariable>(base)->getVar()->name));
+    return true;
+  };
+  return verifyAdjacentElements(isBase, isInvalid, elements);
+}
+
+LogicalResult
+OpFormatParser::verifyAttrDictRegion(SMLoc loc,
+                                     ArrayRef<FormatElement *> elements) {
+  auto isBase = [](FormatElement *el) {
+    if (auto *attrDict = dyn_cast<AttrDictDirective>(el))
+      return !attrDict->isWithKeyword();
+    return false;
+  };
+  auto isInvalid = [&](FormatElement *base, FormatElement *el) {
+    auto *region = dyn_cast<RegionVariable>(el);
+    if (!region)
+      return false;
+    (void)emitErrorAndNote(
+        loc,
+        llvm::formatv("format ambiguity caused by `attr-dict` directive "
+                      "followed by region `{0}`",
+                      region->getVar()->name),
+        "try using `attr-dict-with-keyword` instead");
+    return true;
+  };
+  return verifyAdjacentElements(isBase, isInvalid, elements);
 }
 
 LogicalResult OpFormatParser::verifyOperands(