diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOpsBase.td
@@ -20,6 +20,7 @@
     CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
   let storageType = [{ AffineMapAttr }];
   let returnType = [{ AffineMap }];
+  let valueType = Index;
   let constBuilderCall = "AffineMapAttr::get($0)";
 }
 
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -319,7 +319,8 @@
 def AnyType : Type<CPred<"true">, "any type">;
 
 // None type
-def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">;
+def NoneType : Type<CPred<"$_self.isa<NoneType>()">, "none type">,
+               BuildableType<"$_builder.getType<NoneType>()">;
 
 // Any type from the given list
 class AnyTypeOf<list<Type> allowedTypes, string description = ""> : Type<
@@ -835,6 +836,7 @@
 def BoolAttr : Attr<CPred<"$_self.isa<BoolAttr>()">, "bool attribute"> {
   let storageType = [{ BoolAttr }];
   let returnType = [{ bool }];
+  let valueType = I1;
   let constBuilderCall = "$_builder.getBoolAttr($0)";
 }
 
@@ -942,11 +944,18 @@
   let constBuilderCall = "$_builder.getStringAttr(\"$0\")";
   let storageType = [{ StringAttr }];
   let returnType = [{ StringRef }];
+  let valueType = NoneType;
 }
 
 def StrAttr : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
                               "string attribute">;
 
+// String attribute that has a specific value type.
+class TypedStrAttr<Type ty> : StringBasedAttr<CPred<"$_self.isa<StringAttr>()">,
+                                              "string attribute"> {
+  let valueType = ty;
+}
+
 // Base class for attributes containing types. Example:
 //   def IntTypeAttr : TypeAttrBase<"IntegerType", "integer type attribute">
 // defines a type attribute containing an integer type.
@@ -957,6 +966,7 @@
         description> {
   let storageType = [{ TypeAttr }];
   let returnType = retType;
+  let valueType = NoneType;
   let convertFromStorage = "$_self.getValue().cast<" # retType # ">()";
 }
 
@@ -970,6 +980,7 @@
   let constBuilderCall = "$_builder.getUnitAttr()";
   let convertFromStorage = "$_self != nullptr";
   let returnType = "bool";
+  let valueType = NoneType;
   let isOptional = 1;
 }
 
@@ -1166,6 +1177,7 @@
                           "dictionary of named attribute values"> {
   let storageType = [{ DictionaryAttr }];
   let returnType = [{ DictionaryAttr }];
+  let valueType = NoneType;
   let convertFromStorage = "$_self";
 }
 
@@ -1285,6 +1297,7 @@
     Attr<condition, description> {
   let storageType = [{ ArrayAttr }];
   let returnType = [{ ArrayAttr }];
+  let valueType = NoneType;
   let convertFromStorage = "$_self";
 }
 
@@ -1364,6 +1377,7 @@
                          "symbol reference attribute"> {
   let storageType = [{ SymbolRefAttr }];
   let returnType = [{ SymbolRefAttr }];
+  let valueType = NoneType;
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
   let convertFromStorage = "$_self";
 }
@@ -1371,6 +1385,7 @@
                              "flat symbol reference attribute"> {
   let storageType = [{ FlatSymbolRefAttr }];
   let returnType = [{ StringRef }];
+  let valueType = NoneType;
   let constBuilderCall = "$_builder.getSymbolRefAttr($0)";
   let convertFromStorage = "$_self.getValue()";
 }
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -247,7 +247,7 @@
 // CHECK-LABEL: func @string_attr_custom_type
 func @string_attr_custom_type() {
   // CHECK: "string_data" : !foo.string
-  test.string_attr_with_type "string_data"
+  test.string_attr_with_type "string_data" : !foo.string
   return
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -158,15 +158,8 @@
   let arguments = (ins TypeArrayAttr:$attr);
 }
 def TypeStringAttrWithTypeOp : TEST_Op<"string_attr_with_type"> {
-  let arguments = (ins StrAttr:$attr);
-  let printer = [{ p << getAttr("attr"); }];
-  let parser = [{
-    Attribute attr;
-    Type stringType = OpaqueType::get(Identifier::get("foo",
-                                      result.getContext()), "string",
-                                      result.getContext());
-    return parser.parseAttribute(attr, stringType, "attr", result.attributes);
-  }];
+  let arguments = (ins TypedStrAttr<AnyType>:$attr);
+  let assemblyFormat = "$attr attr-dict";
 }
 
 def StrCaseA: StrEnumAttrCase<"A">;
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -1,4 +1,4 @@
-// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s 2>&1 | FileCheck %s --dump-input-on-failure
+// RUN: mlir-tblgen -gen-op-decls -asmformat-error-is-fatal=false -I %S/../../include %s -o=%t 2>&1 | FileCheck %s --dump-input-on-failure
 
 // This file contains tests for the specification of the declarative op format.
 
@@ -275,6 +275,21 @@
 }]> {
   let successors = (successor AnySuccessor:$successor);
 }
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{
+  $attr `:` attr-dict
+}]>, Arguments<(ins ElementsAttr:$attr)>;
+// CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type
+def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{
+  (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<ElementsAttr>:$attr)>;
+// CHECK-NOT: error:
+def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{
+  $attr `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
+def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{
+  (`foo` $attr^)? `:` attr-dict
+}]>, Arguments<(ins OptionalAttr<I1Attr>:$attr)>;
 
 //===----------------------------------------------------------------------===//
 // Coverage Checks
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -92,13 +92,23 @@
   }
   const VarT *getVar() { return var; }
 
-private:
+protected:
   const VarT *var;
 };
 
 /// This class represents a variable that refers to an attribute argument.
-using AttributeVariable =
-    VariableElement<NamedAttribute, Element::Kind::AttributeVariable>;
+struct AttributeVariable
+    : public VariableElement<NamedAttribute, Element::Kind::AttributeVariable> {
+  using VariableElement<NamedAttribute,
+                        Element::Kind::AttributeVariable>::VariableElement;
+
+  /// Return the constant builder call for the type of this attribute, or None
+  /// if it doesn't have one.
+  Optional<StringRef> getTypeBuilder() const {
+    Optional<Type> attrType = var->attr.getValueType();
+    return attrType ? attrType->getBuilderCall() : llvm::None;
+  }
+};
 
 /// This class represents a variable that refers to an operand argument.
 using OperandVariable =
@@ -574,11 +584,9 @@
     // If this attribute has a buildable type, use that when parsing the
     // attribute.
     std::string attrTypeStr;
-    if (Optional<Type> attrType = var->attr.getValueType()) {
-      if (Optional<StringRef> typeBuilder = attrType->getBuilderCall()) {
-        llvm::raw_string_ostream os(attrTypeStr);
-        os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
-      }
+    if (Optional<StringRef> typeBuilder = attr->getTypeBuilder()) {
+      llvm::raw_string_ostream os(attrTypeStr);
+      os << ", " << tgfmt(*typeBuilder, &attrTypeCtx);
     }
 
     body << formatv(attrParserCode, var->attr.getStorageType(), var->name,
@@ -932,8 +940,7 @@
     }
 
     // Elide the attribute type if it is buildable.
-    Optional<Type> attrType = var->attr.getValueType();
-    if (attrType && attrType->getBuilderCall())
+    if (attr->getTypeBuilder())
       body << "  p.printAttributeWithoutType(" << var->name << "Attr());\n";
     else
       body << "  p.printAttribute(" << var->name << "Attr());\n";
@@ -1234,6 +1241,22 @@
     Optional<StringRef> transformer;
   };
 
+  /// Verify the state of operation attributes within the format.
+  LogicalResult verifyAttributes(llvm::SMLoc loc);
+
+  /// Verify the state of operation operands within the format.
+  LogicalResult
+  verifyOperands(llvm::SMLoc loc,
+                 llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+  /// Verify the state of operation results within the format.
+  LogicalResult
+  verifyResults(llvm::SMLoc loc,
+                llvm::StringMap<TypeResolutionInstance> &variableTyResolver);
+
+  /// Verify the state of operation successors within the format.
+  LogicalResult verifySuccessors(llvm::SMLoc loc);
+
   /// Given the values of an `AllTypesMatch` trait, check for inferable type
   /// resolution.
   void handleAllTypesMatchConstraint(
@@ -1357,37 +1380,95 @@
     }
   }
 
-  // Check that all of the result types can be inferred.
-  auto &buildableTypes = fmt.buildableTypes;
-  if (!fmt.allResultTypes) {
-    for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
-      if (seenResultTypes.test(i))
-        continue;
+  // Verify the state of the various operation components.
+  if (failed(verifyAttributes(loc)) ||
+      failed(verifyResults(loc, variableTyResolver)) ||
+      failed(verifyOperands(loc, variableTyResolver)) ||
+      failed(verifySuccessors(loc)))
+    return failure();
 
-      // Check to see if we can infer this type from another variable.
-      auto varResolverIt = variableTyResolver.find(op.getResultName(i));
-      if (varResolverIt != variableTyResolver.end()) {
-        fmt.resultTypes[i].setVariable(varResolverIt->second.type,
-                                       varResolverIt->second.transformer);
-        continue;
-      }
+  // Check to see if we are formatting all of the operands.
+  fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
+    return isa<OperandsDirective>(elt.get());
+  });
+  return success();
+}
+
+LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
+  // Check that there are no `:` literals after an attribute without a constant
+  // type. The attribute grammar contains an optional trailing colon type, which
+  // can lead to unexpected and generally unintended behavior. Given that, it is
+  // better to just error out here instead.
+  using ElementsIterT = llvm::pointee_iterator<
+      std::vector<std::unique_ptr<Element>>::const_iterator>;
+
+  // Stack of the remaining [next, end) ranges of the groups being walked, with
+  // the innermost group at the back. Pushing a nested group may reallocate the
+  // stack, so no reference into it is kept across an `emplace_back`.
+  SmallVector<std::pair<ElementsIterT, ElementsIterT>, 1> iteratorStack;
+  iteratorStack.emplace_back(fmt.elements.begin(), fmt.elements.end());
+  while (!iteratorStack.empty()) {
+    auto &current = iteratorStack.back();
+    if (current.first == current.second) {
+      iteratorStack.pop_back();
+      continue;
+    }
+    Element *element = &*(current.first++);
+
+    // Traverse into optional groups. The enclosing range has already been
+    // advanced past the group, so it resumes after the group once the nested
+    // range is exhausted and popped.
+    if (auto *optional = dyn_cast<OptionalElement>(element)) {
+      auto elements = optional->getElements();
+      iteratorStack.emplace_back(elements.begin(), elements.end());
+      continue;
+    }
 
-      // If the result is not variadic, allow for the case where the type has a
-      // builder that we can use.
-      NamedTypeConstraint &result = op.getResult(i);
-      Optional<StringRef> builder = result.constraint.getBuilderCall();
-      if (!builder || result.constraint.isVariadic()) {
-        return emitError(loc, "format missing instance of result #" + Twine(i) +
-                                  "('" + result.name + "') type");
+    // Only attributes without a constant type builder can be ambiguous.
+    auto *prevAttr = dyn_cast<AttributeVariable>(element);
+    if (!prevAttr || prevAttr->getTypeBuilder())
+      continue;
+
+    // Find the element that may directly follow the attribute. Attribute
+    // dictionaries and optional groups may print nothing, so they are skipped.
+    // The search starts in the innermost group and only moves to an enclosing
+    // group once the current one has no elements left.
+    Element *nextElement = nullptr;
+    for (auto &nextItPair : llvm::reverse(iteratorStack)) {
+      for (ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second;
+           nextIt != nextE; ++nextIt) {
+        if (!isa<AttrDictDirective>(*nextIt) &&
+            !isa<OptionalElement>(*nextIt)) {
+          nextElement = &*nextIt;
+          break;
+        }
       }
-      // Note in the format that this result uses the custom builder.
-      auto it = buildableTypes.insert({*builder, buildableTypes.size()});
-      fmt.resultTypes[i].setBuilderIdx(it.first->second);
+      if (nextElement)
+        break;
     }
+
+    // We are only interested in `:` literals.
+    auto *literal =
+        nextElement ? dyn_cast<LiteralElement>(nextElement) : nullptr;
+    if (!literal || literal->getLiteral() != ":")
+      continue;
+
+    // TODO: Use the location of the literal element itself.
+    return emitError(
+        loc, llvm::formatv("format ambiguity caused by `:` literal found "
+                           "after attribute `{0}` which does not have "
+                           "a buildable type",
+                           prevAttr->getVar()->name));
   }
+  return success();
+}
 
+LogicalResult FormatParser::verifyOperands(
+    llvm::SMLoc loc,
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
   // Check that all of the operands are within the format, and their types can
   // be inferred.
+  auto &buildableTypes = fmt.buildableTypes;
   for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) {
     NamedTypeConstraint &operand = op.getOperand(i);
 
@@ -1419,22 +1500,57 @@
     auto it = buildableTypes.insert({*builder, buildableTypes.size()});
     fmt.operandTypes[i].setBuilderIdx(it.first->second);
   }
+  return success();
+}
 
-  // Check that all of the successors are within the format.
-  if (!hasAllSuccessors) {
-    for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
-      const NamedSuccessor &successor = op.getSuccessor(i);
-      if (!seenSuccessors.count(&successor)) {
-        return emitError(loc, "format missing instance of successor #" +
-                                  Twine(i) + "('" + successor.name + "')");
-      }
+LogicalResult FormatParser::verifyResults(
+    llvm::SMLoc loc,
+    llvm::StringMap<TypeResolutionInstance> &variableTyResolver) {
+  // If we format all of the types together, there is nothing to check.
+  if (fmt.allResultTypes)
+    return success();
+
+  // Check that all of the result types can be inferred.
+  auto &buildableTypes = fmt.buildableTypes;
+  for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) {
+    if (seenResultTypes.test(i))
+      continue;
+
+    // Check to see if we can infer this type from another variable.
+    auto varResolverIt = variableTyResolver.find(op.getResultName(i));
+    if (varResolverIt != variableTyResolver.end()) {
+      fmt.resultTypes[i].setVariable(varResolverIt->second.type,
+                                     varResolverIt->second.transformer);
+      continue;
     }
+
+    // If the result is not variadic, allow for the case where the type has a
+    // builder that we can use.
+    NamedTypeConstraint &result = op.getResult(i);
+    Optional<StringRef> builder = result.constraint.getBuilderCall();
+    if (!builder || result.constraint.isVariadic()) {
+      return emitError(loc, "format missing instance of result #" + Twine(i) +
+                                "('" + result.name + "') type");
+    }
+    // Note in the format that this result uses the custom builder.
+    auto it = buildableTypes.insert({*builder, buildableTypes.size()});
+    fmt.resultTypes[i].setBuilderIdx(it.first->second);
   }
+  return success();
+}
 
-  // Check to see if we are formatting all of the operands.
-  fmt.allOperands = llvm::any_of(fmt.elements, [](auto &elt) {
-    return isa<OperandsDirective>(elt.get());
-  });
+LogicalResult FormatParser::verifySuccessors(llvm::SMLoc loc) {
+  // Check that all of the successors are within the format.
+  if (hasAllSuccessors)
+    return success();
+
+  for (unsigned i = 0, e = op.getNumSuccessors(); i != e; ++i) {
+    const NamedSuccessor &successor = op.getSuccessor(i);
+    if (!seenSuccessors.count(&successor)) {
+      return emitError(loc, "format missing instance of successor #" +
+                                Twine(i) + "('" + successor.name + "')");
+    }
+  }
   return success();
 }
 