diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2247,6 +2247,12 @@ }]; } +def FormatLiteralFollowingOptionalGroup + : TEST_Op<"format_literal_following_optional_group"> { + let arguments = (ins TypeAttr:$type, OptionalAttr:$value); + let assemblyFormat = "(`(` $value^ `)`)? `:` $type attr-dict"; +} + //===----------------------------------------------------------------------===// // AllTypesMatch type inference diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -505,18 +505,6 @@ }]> { let successors = (successor AnySuccessor:$successor); } -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidH : TestFormat_Op<[{ - $attr `:` attr-dict -}]>, Arguments<(ins ElementsAttr:$attr)>; -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidI : TestFormat_Op<[{ - (`foo` $attr^)?
`:` attr-dict -}]>, Arguments<(ins OptionalAttr:$attr)>; -// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidJ : TestFormat_Op<[{ - $attr ` ` `:` attr-dict -}]>, Arguments<(ins ElementsAttr:$attr)>; // CHECK: error: region 'region' is already bound def VariableInvalidK : TestFormat_Op<[{ $region $region attr-dict diff --git a/mlir/test/mlir-tblgen/op-format-verify.td b/mlir/test/mlir-tblgen/op-format-verify.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format-verify.td @@ -0,0 +1,166 @@ +// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op traits = []> + : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Format ambiguity caused by attribute followed by colon literal +//===----------------------------------------------------------------------===// + +// Test attribute followed by a colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeA : TestFormat_Op<[{ + $attr `:` attr-dict +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test optional attribute followed by colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeB : TestFormat_Op<[{ + (`foo` $attr^)? `:` attr-dict +}]>, Arguments<(ins OptionalAttr:$attr)>; + +// Test attribute followed by whitespace and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeC : TestFormat_Op<[{ + $attr ` ` `:` attr-dict +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test attribute followed by optional dictionary and then colon.
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeD : TestFormat_Op<[{ + $attr attr-dict `:` +}]>, Arguments<(ins AnyAttr:$attr)>; + +// Test attribute followed by optional group and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeE : TestFormat_Op<[{ + $attr ($a^)? `:` attr-dict type($a) +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with literals and then colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeF : TestFormat_Op<[{ + $attr (`(` $a^ `)`)? `:` attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with else group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeG : TestFormat_Op<[{ + $attr (`(` $a^ `)`) : (`foo`)? `:` attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeH : TestFormat_Op<[{ + $attr (`:` $a^ `)`)? attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by optional group with colon in else group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeI : TestFormat_Op<[{ + $attr (`(` $a^ `)`) : (`:`)? attr-dict (`(` type($a)^ `)`)? +}]>, Arguments<(ins AnyAttr:$attr, Optional:$a)>; + +// Test attribute followed by two optional groups and then a colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeJ : TestFormat_Op<[{ + $attr (`(` $a^ type($a) `)`) : (`foo`)? ` ` attr-dict (`(` $b^ type($b) `)`)?
+ `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test attribute followed by two optional groups and then a colon in the else +// group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeK : TestFormat_Op<[{ + $attr (`(` $a^ type($a) `)`) : (`foo`)? ` ` attr-dict + (`(` $b^ type($b) `)`) : (`:`)? +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test attribute followed by two optional groups with guarded colons but then a +// colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeL : TestFormat_Op<[{ + $attr (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test optional attribute followed by optional groups with a colon along one +// path. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeM : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// Test optional attribute followed by optional groups with a colon along one +// path inside an optional group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeN : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`:`)? +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// Test attribute followed by optional attribute, operand, successor, region, +// and a colon.
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` +def AmbiguousTypeO : TestFormat_Op<[{ + $attr attr-dict $a $b $c $d $e `:` +}], [AttrSizedOperandSegments]> { + let arguments = (ins AnyAttr:$attr, OptionalAttr:$a, + Optional:$b, Variadic:$c); + let successors = (successor VariadicSuccessor:$d); + let regions = (region VariadicRegion:$e); +} + +// Test two attributes, where the second is followed by a colon-guarded group. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `b` +def AmbiguousTypeP : TestFormat_Op<[{ + $a attr-dict `(` `:` $b (`:` $c^)? +}]>, Arguments<(ins AnyAttr:$a, AnyAttr:$b, Optional:$c)>; + +// Test two attributes, where the second is directly followed by a colon. +// CHECK: error: format ambiguity caused by `:` literal found after attribute `b` +def AmbiguousTypeQ : TestFormat_Op<[{ + $a attr-dict (`(` $c^ `:`)? `(` `:` $b `:` +}]>, Arguments<(ins AnyAttr:$a, AnyAttr:$b, Optional:$c)>; + +// CHECK-NOT: error + +// Test attribute followed by two optional groups with guarded colons. +def ValidTypeA : TestFormat_Op<[{ + $attr (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? ` ` `(` `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins AnyAttr:$attr, Optional:$a, Optional:$b)>; + +// Test optional attribute followed by two optional groups with guarded colons. +def ValidTypeB : TestFormat_Op<[{ + (`(` $attr^ ` `)? (`(` $a^ `:` type($a) `)`) : (`foo` `:`)? ` ` attr-dict + (`(` $b^ `:` type($b) `)`) : (`foo` `:`)? ` ` `(` `:` +}], [AttrSizedOperandSegments]>, + Arguments<(ins OptionalAttr:$attr, Optional:$a, + Optional:$b)>; + +// Test optional attribute followed by a non-colon literal within its group. +def ValidTypeC : TestFormat_Op<[{ + (`(` $attr^ `)`) : (`:`)? attr-dict `:` +}]>, Arguments<(ins OptionalAttr:$attr)>; + +// Test that a required optional group guard blocks the colon. +def ValidTypeD : TestFormat_Op<[{ + $a attr-dict ($c^ `:`)?
+}]>, Arguments<(ins AnyAttr:$a, Optional:$c)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -389,6 +389,9 @@ return } +// CHECK: test.format_literal_following_optional_group(5 : i32) : i32 {a} +test.format_literal_following_optional_group(5 : i32) : i32 {a} + //===----------------------------------------------------------------------===// // Format trait type inference //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -10,11 +10,8 @@ #include "FormatGen.h" #include "OpClass.h" #include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/Format.h" -#include "mlir/TableGen/GenInfo.h" -#include "mlir/TableGen/Interfaces.h" #include "mlir/TableGen/Operator.h" #include "mlir/TableGen/Trait.h" #include "llvm/ADT/MapVector.h" @@ -22,10 +19,9 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Signals.h" -#include "llvm/TableGen/Error.h" +#include "llvm/Support/SourceMgr.h" #include "llvm/TableGen/Record.h" #define DEBUG_TYPE "mlir-tblgen-opformatgen" @@ -2196,14 +2192,12 @@ Optional transformer; }; - using ElementsItT = ArrayRef::iterator; - /// Verify the state of operation attributes within the format. LogicalResult verifyAttributes(SMLoc loc, ArrayRef elements); - /// Verify the attribute elements at the back of the given stack of iterators. - LogicalResult verifyAttributes( - SMLoc loc, - SmallVectorImpl> &iteratorStack); + + /// Verify that attribute elements aren't followed by colon literals.
+ LogicalResult verifyAttributeColonType(SMLoc loc, + ArrayRef elements); /// Verify the state of operation operands within the format. LogicalResult @@ -2338,11 +2332,8 @@ // type. The attribute grammar contains an optional trailing colon type, which // can lead to unexpected and generally unintended behavior. Given that, it is // better to just error out here instead. - SmallVector, 1> iteratorStack; - iteratorStack.emplace_back(elements.begin(), elements.end()); - while (!iteratorStack.empty()) - if (failed(verifyAttributes(loc, iteratorStack))) - return ::failure(); + if (failed(verifyAttributeColonType(loc, elements))) + return failure(); // Check for VariadicOfVariadic variables. The segment attribute of those // variables will be infered. @@ -2355,61 +2346,127 @@ return success(); } -/// Verify the attribute elements at the back of the given stack of iterators. -LogicalResult OpFormatParser::verifyAttributes( - SMLoc loc, - SmallVectorImpl> &iteratorStack) { - auto &stackIt = iteratorStack.back(); - ElementsItT &it = stackIt.first, e = stackIt.second; - while (it != e) { - FormatElement *element = *(it++); - - // Traverse into optional groups. - if (auto *optional = dyn_cast(element)) { - auto thenElements = optional->getThenElements(); - iteratorStack.emplace_back(thenElements.begin(), thenElements.end()); - auto elseElements = optional->getElseElements(); - iteratorStack.emplace_back(elseElements.begin(), elseElements.end()); - return success(); - } +/// Returns whether the single format element is optionally parsed.
+static bool isOptionallyParsed(FormatElement *el) { + if (auto *attrVar = dyn_cast(el)) { + Attribute attr = attrVar->getVar()->attr; + return attr.isOptional() || attr.hasDefaultValue(); + } + if (auto *operandVar = dyn_cast(el)) { + const NamedTypeConstraint *operand = operandVar->getVar(); + return operand->isOptional() || operand->isVariadic() || + operand->isVariadicOfVariadic(); + } + if (auto *successorVar = dyn_cast(el)) + return successorVar->getVar()->isVariadic(); + if (auto *regionVar = dyn_cast(el)) + return regionVar->getVar()->isVariadic(); + return isa(el); +} - // We are checking for an attribute element followed by a `:`, so there is - // no need to check the end. - if (it == e && iteratorStack.size() == 1) - break; +/// Scan the given range of elements from the start for a colon literal, +/// skipping any optionally-parsed elements. If an optional group is +/// encountered, this function recurses into the 'then' and 'else' elements to +/// check if they are invalid. Returns `success` if the range is known to be +/// valid, `failure` if a `:` literal was found, or `None` if the end was reached. +/// +/// Since the guard element of an optional group is required, this function +/// accepts an optional element pointer to mark it as required. +static Optional checkElementRangeForColon( + function_ref emitError, StringRef attrName, + iterator_range::iterator> elementRange, + FormatElement *optionalGuard = nullptr) { + for (FormatElement *element : elementRange) { + // Skip optionally parsed elements. + if (element != optionalGuard && isOptionallyParsed(element)) + continue; - // Check for an attribute with a constant type builder, followed by a `:`. - auto *prevAttr = dyn_cast(element); - if (!prevAttr || prevAttr->getTypeBuilder()) + // Recurse on optional groups. + if (auto *optional = dyn_cast(element)) { + if (Optional result = checkElementRangeForColon( + emitError, attrName, optional->getThenElements(), + // The optional group guard (its first 'then' element) is required.
+ optional->getThenElements().front())) + if (failed(*result)) + return failure(); + if (Optional result = checkElementRangeForColon( + emitError, attrName, optional->getElseElements())) + if (failed(*result)) + return failure(); + // Skip the optional group. continue; + } - // Check the next iterator within the stack for literal elements. - for (auto &nextItPair : iteratorStack) { - ElementsItT nextIt = nextItPair.first, nextE = nextItPair.second; - for (; nextIt != nextE; ++nextIt) { - // Skip any trailing whitespace, attribute dictionaries, or optional - // groups. - if (isa(*nextIt) || - isa(*nextIt) || isa(*nextIt)) - continue; + // If we encounter anything other than `:`, this range is valid. + auto *literal = dyn_cast(element); + if (!literal || literal->getSpelling() != ":") + return success(); + // If we encounter `:`, the range is known to be invalid. + return emitError( + llvm::formatv("format ambiguity caused by `:` literal found after " + "attribute `{0}` which does not have a buildable type", + attrName)); + } + // Return None to indicate that we reached the end. + return llvm::None; +} - // We are only interested in `:` literals. - auto *literal = dyn_cast(*nextIt); - if (!literal || literal->getSpelling() != ":") - break; +/// For the given elements, check whether any attributes are followed by a colon +/// literal, resulting in an ambiguous assembly format. Returns a non-null +/// attribute if verification of said attribute reached the end of the range. +/// Returns null if all attribute elements are verified; failure on error. +static FailureOr +verifyAttributeColon(function_ref emitError, + ArrayRef elements) { + for (auto *it = elements.begin(), *e = elements.end(); it != e; ++it) { + // The current attribute being verified. + AttributeVariable *attr = nullptr; + + if ((attr = dyn_cast(*it))) { + // Check only attributes without type builders and that are known to call + // the generic attribute parser.
+ if (attr->getTypeBuilder() || + !(attr->shouldBeQualified() || + attr->getVar()->attr.getStorageType() == "::mlir::Attribute")) + continue; + } else if (auto *optional = dyn_cast(*it)) { + // Recurse on optional groups. + FailureOr thenResult = + verifyAttributeColon(emitError, optional->getThenElements()); + if (failed(thenResult)) + return failure(); + FailureOr elseResult = + verifyAttributeColon(emitError, optional->getElseElements()); + if (failed(elseResult)) + return failure(); + // If either optional group has an unverified attribute, save it. + // Otherwise, move on to the next element. + if (!(attr = *thenResult) && !(attr = *elseResult)) + continue; + } else { + continue; + } - // TODO: Use the location of the literal element itself. - return emitError( - loc, llvm::formatv("format ambiguity caused by `:` literal found " - "after attribute `{0}` which does not have " - "a buildable type", - prevAttr->getVar()->name)); - } + // Verify subsequent elements for potential ambiguities. + if (Optional result = checkElementRangeForColon( + emitError, attr->getVar()->name, {std::next(it), e})) { + if (failed(*result)) + return failure(); + } else { + // Since we reached the end, return the attribute as unverified. + return attr; } } - iteratorStack.pop_back(); - return success(); + // All attribute elements are known to be verified. + return nullptr; +} + +LogicalResult +OpFormatParser::verifyAttributeColonType(SMLoc loc, + ArrayRef elements) { + return verifyAttributeColon( + [&](const Twine &msg) { return emitError(loc, msg); }, elements); } LogicalResult OpFormatParser::verifyOperands(