diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2247,6 +2247,12 @@ }]; } +def FormatLiteralFollowingOptionalGroup + : TEST_Op<"format_literal_following_optional_group"> { + let arguments = (ins TypeAttr:$type, OptionalAttr:$value); + let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict"; +} + //===----------------------------------------------------------------------===// // AllTypesMatch type inference diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -505,18 +505,6 @@ }]> { let successors = (successor AnySuccessor:$successor); } -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidH : TestFormat_Op<[{ - $attr `:` attr-dict -}]>, Arguments<(ins ElementsAttr:$attr)>; -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidI : TestFormat_Op<[{ - (`foo` $attr^)? 
`:` attr-dict -}]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidJ : TestFormat_Op<[{ - $attr ` ` `:` attr-dict -}]>, Arguments<(ins ElementsAttr:$attr)>; // CHECK: error: region 'region' is already bound def VariableInvalidK : TestFormat_Op<[{ $region $region attr-dict diff --git a/mlir/test/mlir-tblgen/op-format-verify.td b/mlir/test/mlir-tblgen/op-format-verify.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format-verify.td @@ -0,0 +1,137 @@ +// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op traits = []> + : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Format ambiguity caused by attribute followed by colon literal +//===----------------------------------------------------------------------===// + +// Test attribute followed by a colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeA : TestFormat_Op<[{ + $attr `:` attr-dict +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test optional attribute followed by colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeB : TestFormat_Op<[{ + (`foo` $attr^)? `:` attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; + +// Test attribute followed by whitespace and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeC : TestFormat_Op<[{ + $attr ` ` `:` attr-dict +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test attribute followed by optional dictionary and then colon. 
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeD : TestFormat_Op<[{ + $attr attr-dict `:` +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test attribute followed by optional group and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeE : TestFormat_Op<[{ + $attr ($a^)? `:` attr-dict type($a) +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with literals and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeF : TestFormat_Op<[{ + $attr (`(` $a^ `)`)? `:` attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with else group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeG : TestFormat_Op<[{ + $attr (`(` $a^ `)`) : (`foo`)? `:` attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeH : TestFormat_Op<[{ + $attr (` ` `:` $a^ `)`)? attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with colon in else group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeI : TestFormat_Op<[{ + $attr (`(` $a^ `)`) : (`:`)? attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by two optional groups and then a colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeJ : TestFormat_Op<[{ + $attr (`(` $a^ type($a) `)`) : (`foo`)? ` ` attr-dict (`(` $b^ type($b) `)`)? 
+ `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test attribute followed by two optional groups and then a colon in the else +// group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeK : TestFormat_Op<[{ + $attr (`(` $a^ type($a) `)`) : (`foo`)? ` ` attr-dict + (`(` $b^ type($b) `)`) : (`:`)? +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test attribute followed by two optional groups with guarded colons but then a +// colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeL : TestFormat_Op<[{ + $attr (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test optional attribute followed by optional groups with a colon along one +// path. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeM : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// Test optional attribute followed by optional groups with a colon along one +// path inside an optional group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeN : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`:`)? +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// CHECK-NOT: error + +// Test attribute followed by two optional groups with guarded colons. 
+def ValidTypeA : TestFormat_Op<[{ + $attr (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? ` ` `(` `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test optional attribute followed by two optional groups with guarded colons. +def ValidTypeB : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? ` ` `(` `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// Test optional attribute followed by a guarding literal within its group. +def ValidTypeC : TestFormat_Op<[{ + (`(` $attr^ `)`) : (`:`)? attr-dict `:` +}]>, Arguments<(ins OptionalAttr:$attr)>; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -10,11 +10,8 @@ #include "FormatGen.h" #include "OpClass.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" -#include "mlir/TableGen/GenInfo.h" -#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/MapVector.h" @@ -22,11 +19,11 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Signals.h" -#include "llvm/TableGen/Error.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Record.h" +#include #define DEBUG_TYPE "mlir-tblgen-opformatgen" @@ -2199,14 +2196,17 @@ Optional transformer; }; - using ElementsItT = ArrayRef::iterator; - /// Verify the state of operation attributes within the format. 
LogicalResult verifyAttributes(SMLoc loc, ArrayRef elements); + - /// Verify the attribute elements at the back of the given stack of iterators. - LogicalResult verifyAttributes( - SMLoc loc, - SmallVectorImpl> &iteratorStack); + /// Verify that no attribute without a buildable type can be followed by `:`. + LogicalResult verifyAttributeColonType(SMLoc loc, + ArrayRef elements); + + /// Verify the assembly format elements to ensure there are no ambiguities + /// with order-independent lists. + LogicalResult verifyOIListElements(SMLoc loc, + ArrayRef elements); /// Verify the state of operation operands within the format. LogicalResult @@ -2224,9 +2224,6 @@ /// Verify the state of operation successors within the format. LogicalResult verifySuccessors(SMLoc loc); - LogicalResult verifyOIListElements(SMLoc loc, - ArrayRef elements); - /// Given the values of an `AllTypesMatch` trait, check for inferable type /// resolution. void handleAllTypesMatchConstraint( @@ -2325,8 +2322,7 @@ if (failed(verifyAttributes(loc, elements)) || failed(verifyResults(loc, variableTyResolver)) || failed(verifyOperands(loc, variableTyResolver)) || - failed(verifyRegions(loc)) || failed(verifySuccessors(loc)) || - failed(verifyOIListElements(loc, elements))) + failed(verifyRegions(loc)) || failed(verifySuccessors(loc))) return failure(); // Collect the set of used attributes in the format. @@ -2341,11 +2337,12 @@ // type. The attribute grammar contains an optional trailing colon type, which // can lead to unexpected and generally unintended behavior. Given that, it is // better to just error out here instead. - SmallVector, 1> iteratorStack; - iteratorStack.emplace_back(elements.begin(), elements.end()); - while (!iteratorStack.empty()) - if (failed(verifyAttributes(loc, iteratorStack))) - return ::failure(); + if (failed(verifyAttributeColonType(loc, elements))) + return failure(); + + // Verify the format elements with respect to order-independent lists. 
+ if (failed(verifyOIListElements(loc, elements))) + return failure(); // Check for VariadicOfVariadic variables. 
The segment attribute of those // variables will be infered. @@ -2358,60 +2355,177 @@ return success(); } -/// Verify the attribute elements at the back of the given stack of iterators. -LogicalResult OpFormatParser::verifyAttributes( - SMLoc loc, - SmallVectorImpl> &iteratorStack) { - auto &stackIt = iteratorStack.back(); - ElementsItT &it = stackIt.first, e = stackIt.second; - while (it != e) { - FormatElement *element = *(it++); - - // Traverse into optional groups. - if (auto *optional = dyn_cast(element)) { - auto thenElements = optional->getThenElements(); - iteratorStack.emplace_back(thenElements.begin(), thenElements.end()); - - auto elseElements = optional->getElseElements(); - iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); - return success(); - } - - // We are checking for an attribute element followed by a `:`, so there is - // no need to check the end. - if (it == e && iteratorStack.size() == 1) - break; - // Check for an attribute with a constant type builder, followed by a `:`. - auto *prevAttr = dyn_cast(element); - if (!prevAttr || prevAttr->getTypeBuilder()) +/// Given a segment of a parsing path, check whether the first non-optional +/// element is a colon. Returns failure if colon was found, meaning the path is +/// invalid. Returns success if the path is proved to be valid, and returns None +/// if iteration reached the end of the path. +static Optional checkPathSegmentForColon( + function_ref emitError, StringRef attrName, + iterator_range::iterator> path) { + for (FormatElement *element : path) { + // Skip whitespaces and attribute dictionaries as they are optionally + // parsed. + if (isa(element)) continue; + // If we encounter anything other than `:`, this parsing path is valid. 
+ auto *literal = dyn_cast(element); + if (!literal || literal->getSpelling() != ":") + return success(); + return emitError( + llvm::formatv("format ambiguity caused by `:` literal found after " + "attribute `{0}` which does not have a buildable type", + attrName)); + } + return llvm::None; +} - // Check the next iterator within the stack for literal elements. - for (auto &nextItPair : iteratorStack) { - ElementsItT nextIt = nextItPair.first, nextE = nextItPair.second; - for (; nextIt != nextE; ++nextIt) { - // Skip any trailing whitespace, attribute dictionaries, or optional - // groups. - if (isa(*nextIt) || - isa(*nextIt) || isa(*nextIt)) +// We need to check for an attribute followed by a colon literal, skipping any +// optional elements, along all possible parsing paths. This function will +// "instantiate" optional groups into a series of path segments. Consider the +// following assembly format: +// +// (`foo` $attr^)? attr-dict (`(` $a^ `)`) : (`foo`)? (`:` $b^)? +// +// When the optional groups are instantiated, the parsing paths will look like: +// +// [`foo`, $attr] [attr-dict] [`(`, $a, `)`] [`:`, $b] +// [] [`foo`] [] +// [] +// +// This function will then find any attribute elements among the path segments +// and check until the end of that path segment for a colon. If it reaches the +// end, it will check the next group of segments for a possible ambiguity. If +// one is found, the function returns failure. If it reaches the end of any +// segment, it repeats on the next group. +// +// TODO: Make this less complicated. +LogicalResult OpFormatParser::verifyAttributeColonType( + SMLoc loc, + ArrayRef elements) { + using ParsingBranch = SmallVector, 2>; + std::vector parsingPaths; + parsingPaths.reserve(elements.size()); + + auto *it = elements.begin(), *e = elements.end(); + auto *start = it; + while (it != e) { + auto *optional = dyn_cast(*it); + // Iterate until an optional element is found. 
+ if (!optional) { + ++it; + continue; + } + // Push the current elements. + if (it != start) { + parsingPaths.emplace_back().push_back( + {start, static_cast(std::distance(start, it))}); + } + // Skip the optional element. + ++it; + start = it; + // Push its children. + SmallVectorImpl> &path = + parsingPaths.emplace_back(); + path.push_back(optional->getThenElements()); + if (!optional->getElseElements().empty()) + path.push_back(optional->getElseElements()); + // Push an empty branch. + path.push_back({}); + } + // Since we iterated to the end, push any remaining elements. + if (it != start) { + parsingPaths.emplace_back().push_back( + {start, static_cast(std::distance(start, it))}); + } + + assert(!parsingPaths.empty() && !parsingPaths.front().empty() && + "unexpected empty assembly format"); + // Iterate until we find an attribute element. Then, starting from that + // position, check until the end of the current path segment, then move on to + // the next segments. Check all branches of subsequent segments. + for (auto pathIt = parsingPaths.begin(), pathEnd = parsingPaths.end(); + pathIt != pathEnd; ++pathIt) { + for (auto *branchIt = pathIt->begin(), *branchEnd = pathIt->end(); + branchIt != branchEnd; ++branchIt) { + for (auto *elIt = branchIt->begin(), *elEnd = branchIt->end(); + elIt != elEnd; ++elIt) { + // Check attributes that lack a type builder. + auto *attr = dyn_cast(*elIt); + if (!attr || attr->getTypeBuilder()) continue; + auto checkPath = + [&](iterator_range::iterator> path) { + return checkPathSegmentForColon( + [&](const Twine &msg) { return emitError(loc, msg); }, + attr->getVar()->name, path); + }; + if (Optional result = + checkPath(llvm::make_range(std::next(elIt), elEnd))) { + // The path was resolved within the current segment: stop only on an + // ambiguity, otherwise move on to the next attribute. + if (failed(*result)) + return failure(); + continue; + } + // Check all possible parsing paths starting from the next segment. 
+ for (auto secondPathIt = std::next(pathIt); secondPathIt != pathEnd; + ++secondPathIt) { + bool reachedEnd = false; + for (ArrayRef path : *secondPathIt) { + if (Optional result = checkPath(path)) { + if (failed(*result)) + return failure(); + continue; + } + reachedEnd = true; + } + // If every branch of this segment resolved the path, no later segment + // can be reached from this attribute; move on to the next one. + if (!reachedEnd) + break; + } + } + } + } - // We are only interested in `:` literals. - auto *literal = dyn_cast(*nextIt); - if (!literal || literal->getSpelling() != ":") - break; + return success(); +} - // TODO: Use the location of the literal element itself. +LogicalResult +OpFormatParser::verifyOIListElements(SMLoc loc, + ArrayRef elements) { + // Literal spellings of the preceding adjacent oilist elements. + SmallVector prohibitedLiterals; + for (FormatElement *it : elements) { + if (auto *oilist = dyn_cast(it)) { + if (!prohibitedLiterals.empty()) { + // We just saw an oilist element in last iteration. Literals should not + // match. + for (LiteralElement *literal : oilist->getLiteralElements()) { + if (find(prohibitedLiterals, literal->getSpelling()) != + prohibitedLiterals.end()) { + return emitError( + loc, "format ambiguity because " + literal->getSpelling() + + " is used in two adjacent oilist elements."); + } + } + } + for (LiteralElement *literal : oilist->getLiteralElements()) + prohibitedLiterals.push_back(literal->getSpelling()); + } else if (auto *literal = dyn_cast(it)) { + if (find(prohibitedLiterals, literal->getSpelling()) != + prohibitedLiterals.end()) { return emitError( - loc, llvm::formatv("format ambiguity caused by `:` literal found " - "after attribute `{0}` which does not have " - "a buildable type", - prevAttr->getVar()->name)); + loc, + "format ambiguity because " + literal->getSpelling() + + " is used both in oilist element and the adjacent literal."); } + prohibitedLiterals.clear(); + } else { + prohibitedLiterals.clear(); } } - iteratorStack.pop_back(); 
return success(); } @@ -2546,43 +2660,6 @@ return success(); } -LogicalResult 
-OpFormatParser::verifyOIListElements(SMLoc loc, - ArrayRef elements) { - // Check that all of the successors are within the format. - SmallVector prohibitedLiterals; - for (FormatElement *it : elements) { - if (auto *oilist = dyn_cast(it)) { - if (!prohibitedLiterals.empty()) { - // We just saw an oilist element in last iteration. Literals should not - // match. - for (LiteralElement *literal : oilist->getLiteralElements()) { - if (find(prohibitedLiterals, literal->getSpelling()) != - prohibitedLiterals.end()) { - return emitError( - loc, "format ambiguity because " + literal->getSpelling() + - " is used in two adjacent oilist elements."); - } - } - } - for (LiteralElement *literal : oilist->getLiteralElements()) - prohibitedLiterals.push_back(literal->getSpelling()); - } else if (auto *literal = dyn_cast(it)) { - if (find(prohibitedLiterals, literal->getSpelling()) != - prohibitedLiterals.end()) { - return emitError( - loc, - "format ambiguity because " + literal->getSpelling() + - " is used both in oilist element and the adjacent literal."); - } - prohibitedLiterals.clear(); - } else { - prohibitedLiterals.clear(); - } - } - return success(); -} - void OpFormatParser::handleAllTypesMatchConstraint( ArrayRef values, llvm::StringMap &variableTyResolver) {