diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -613,6 +613,15 @@ - Represents all of the operands of an operation. +* `ref` ( input ) + + - Represents a reference to a variable or directive that must have + already been resolved, and that may be used as a parameter to a `custom` + directive. + - Used to pass previously parsed entities to custom directives. + - The input may be any directive or variable, aside from `functional-type` + and `custom`. + * `regions` - Represents all of the regions of an operation. @@ -631,14 +640,6 @@ - `input` must be either an operand or result [variable](#variables), the `operands` directive, or the `results` directive. -* `type_ref` ( input ) - - - Represents a reference to the type of the given input that must have - already been resolved. - - `input` must be either an operand or result [variable](#variables), the - `operands` directive, or the `results` directive. - - Used to pass previously parsed types to custom directives. - #### Literals A literal is either a keyword or punctuation surrounded by \`\`. @@ -716,6 +717,10 @@ - Single: `OpAsmParser::OperandType &` - Optional: `Optional &` - Variadic: `SmallVectorImpl &` +* Ref Directives + - A reference directive is passed to the parser using the same mapping as + the input operand. For example, a single region would be passed as a + `Region &`.
* Region Variables - Single: `Region &` - Variadic: `SmallVectorImpl> &` @@ -726,10 +731,6 @@ - Single: `Type &` - Optional: `Type &` - Variadic: `SmallVectorImpl &` -* TypeRef Directives - - Single: `Type` - - Optional: `Type` - - Variadic: `const SmallVectorImpl &` * `attr-dict` Directive: `NamedAttrList &` When a variable is optional, the value should only be specified if the variable @@ -748,6 +749,10 @@ - Single: `Value` - Optional: `Value` - Variadic: `OperandRange` +* Ref Directives + - A reference directive is passed to the printer using the same mapping as + the input operand. For example, a single region would be passed as a + `Region &`. * Region Variables - Single: `Region &` - Variadic: `MutableArrayRef` @@ -758,10 +763,6 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` -* TypeRef Directives - - Single: `Type` - - Optional: `Type` - - Variadic: `TypeRange` * `attr-dict` Directive: `DictionaryAttr` When a variable is optional, the provided value may be null. diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -352,6 +352,14 @@ NamedAttrList &attrs) { return parser.parseOptionalAttrDict(attrs); } +static ParseResult parseCustomDirectiveOptionalOperandRef( + OpAsmParser &parser, Optional &optOperand) { + int64_t operandCount = 0; + if (parser.parseInteger(operandCount)) + return failure(); + bool expectedOptionalOperand = operandCount == 0; + return success(expectedOptionalOperand != optOperand.hasValue()); +} //===----------------------------------------------------------------------===// // Printing @@ -417,6 +425,13 @@ DictionaryAttr attrs) { printer.printOptionalAttrDict(attrs.getValue()); } + +static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, + Operation *op, + Value optOperand) { + printer << (optOperand ? 
"1" : "0"); +} + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1698,12 +1698,22 @@ type($result), type($optResult), type($varResults) ) custom( - type_ref($result), type_ref($optResult), type_ref($varResults) + ref(type($result)), ref(type($optResult)), ref(type($varResults)) ) attr-dict }]; } +def FormatCustomDirectiveWithOptionalOperandRef + : TEST_Op<"format_custom_directive_with_optional_operand_ref"> { + let arguments = (ins Optional:$optOperand); + let assemblyFormat = [{ + ($optOperand^)? `:` + custom(ref($optOperand)) + attr-dict + }]; +} + def FormatCustomDirectiveSuccessors : TEST_Op<"format_custom_directive_successors", [Terminator]> { let successors = (successor AnySuccessor:$successor, diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -7,8 +7,8 @@ def TestDialect : Dialect { let name = "test"; } -class TestFormat_Op traits = []> - : Op { +class TestFormat_Op traits = []> + : Op { let assemblyFormat = fmt; } @@ -20,25 +20,25 @@ // attr-dict // CHECK: error: 'attr-dict' directive not found -def DirectiveAttrDictInvalidA : TestFormat_Op<"attrdict_invalid_a", [{ +def DirectiveAttrDictInvalidA : TestFormat_Op<[{ }]>; // CHECK: error: 'attr-dict' directive has already been seen -def DirectiveAttrDictInvalidB : TestFormat_Op<"attrdict_invalid_b", [{ +def DirectiveAttrDictInvalidB : TestFormat_Op<[{ attr-dict attr-dict }]>; // CHECK: error: 'attr-dict' directive has already been seen -def DirectiveAttrDictInvalidC : TestFormat_Op<"attrdict_invalid_c", [{ +def 
DirectiveAttrDictInvalidC : TestFormat_Op<[{ attr-dict attr-dict-with-keyword }]>; // CHECK: error: 'attr-dict' directive can only be used as a top-level directive -def DirectiveAttrDictInvalidD : TestFormat_Op<"attrdict_invalid_d", [{ +def DirectiveAttrDictInvalidD : TestFormat_Op<[{ type(attr-dict) }]>; // CHECK-NOT: error -def DirectiveAttrDictValidA : TestFormat_Op<"attrdict_valid_a", [{ +def DirectiveAttrDictValidA : TestFormat_Op<[{ attr-dict }]>; -def DirectiveAttrDictValidB : TestFormat_Op<"attrdict_valid_b", [{ +def DirectiveAttrDictValidB : TestFormat_Op<[{ attr-dict-with-keyword }]>; @@ -46,42 +46,42 @@ // custom // CHECK: error: expected '<' before custom directive name -def DirectiveCustomInvalidA : TestFormat_Op<"custom_invalid_a", [{ +def DirectiveCustomInvalidA : TestFormat_Op<[{ custom( }]>; // CHECK: error: expected custom directive name identifier -def DirectiveCustomInvalidB : TestFormat_Op<"custom_invalid_b", [{ +def DirectiveCustomInvalidB : TestFormat_Op<[{ custom<> }]>; // CHECK: error: expected '>' after custom directive name -def DirectiveCustomInvalidC : TestFormat_Op<"custom_invalid_c", [{ +def DirectiveCustomInvalidC : TestFormat_Op<[{ custom; // CHECK: error: expected '(' before custom directive parameters -def DirectiveCustomInvalidD : TestFormat_Op<"custom_invalid_d", [{ +def DirectiveCustomInvalidD : TestFormat_Op<[{ custom) }]>; // CHECK: error: only variables and types may be used as parameters to a custom directive -def DirectiveCustomInvalidE : TestFormat_Op<"custom_invalid_e", [{ +def DirectiveCustomInvalidE : TestFormat_Op<[{ custom(operands) }]>; // CHECK: error: expected ')' after custom directive parameters -def DirectiveCustomInvalidF : TestFormat_Op<"custom_invalid_f", [{ +def DirectiveCustomInvalidF : TestFormat_Op<[{ custom($operand< }]>, Arguments<(ins I64:$operand)>; // CHECK: error: type directives within a custom directive may only refer to variables -def DirectiveCustomInvalidH : TestFormat_Op<"custom_invalid_h", 
[{ +def DirectiveCustomInvalidH : TestFormat_Op<[{ custom(type(operands)) }]>; // CHECK-NOT: error -def DirectiveCustomValidA : TestFormat_Op<"custom_valid_a", [{ +def DirectiveCustomValidA : TestFormat_Op<[{ custom($operand) attr-dict }]>, Arguments<(ins Optional:$operand)>; -def DirectiveCustomValidB : TestFormat_Op<"custom_valid_b", [{ +def DirectiveCustomValidB : TestFormat_Op<[{ custom($operand, type($operand), type($result)) attr-dict }]>, Arguments<(ins I64:$operand)>, Results<(outs I64:$result)>; -def DirectiveCustomValidC : TestFormat_Op<"custom_valid_c", [{ +def DirectiveCustomValidC : TestFormat_Op<[{ custom($attr) attr-dict }]>, Arguments<(ins I64Attr:$attr)>; @@ -89,31 +89,31 @@ // functional-type // CHECK: error: 'functional-type' is only valid as a top-level directive -def DirectiveFunctionalTypeInvalidA : TestFormat_Op<"functype_invalid_a", [{ +def DirectiveFunctionalTypeInvalidA : TestFormat_Op<[{ functional-type(functional-type) }]>; // CHECK: error: expected '(' before argument list -def DirectiveFunctionalTypeInvalidB : TestFormat_Op<"functype_invalid_b", [{ +def DirectiveFunctionalTypeInvalidB : TestFormat_Op<[{ functional-type }]>; // CHECK: error: expected directive, literal, variable, or optional group -def DirectiveFunctionalTypeInvalidC : TestFormat_Op<"functype_invalid_c", [{ +def DirectiveFunctionalTypeInvalidC : TestFormat_Op<[{ functional-type( }]>; // CHECK: error: expected ',' after inputs argument -def DirectiveFunctionalTypeInvalidD : TestFormat_Op<"functype_invalid_d", [{ +def DirectiveFunctionalTypeInvalidD : TestFormat_Op<[{ functional-type(operands }]>; // CHECK: error: expected directive, literal, variable, or optional group -def DirectiveFunctionalTypeInvalidE : TestFormat_Op<"functype_invalid_e", [{ +def DirectiveFunctionalTypeInvalidE : TestFormat_Op<[{ functional-type(operands, }]>; // CHECK: error: expected ')' after argument list -def DirectiveFunctionalTypeInvalidF : TestFormat_Op<"functype_invalid_f", [{ +def 
DirectiveFunctionalTypeInvalidF : TestFormat_Op<[{ functional-type(operands, results }]>; // CHECK-NOT: error -def DirectiveFunctionalTypeValid : TestFormat_Op<"functype_invalid_a", [{ +def DirectiveFunctionalTypeValid : TestFormat_Op<[{ functional-type(operands, results) attr-dict }]>; @@ -121,45 +121,128 @@ // operands // CHECK: error: 'operands' directive creates overlap in format -def DirectiveOperandsInvalidA : TestFormat_Op<"operands_invalid_a", [{ +def DirectiveOperandsInvalidA : TestFormat_Op<[{ operands operands }]>; // CHECK: error: 'operands' directive creates overlap in format -def DirectiveOperandsInvalidB : TestFormat_Op<"operands_invalid_b", [{ +def DirectiveOperandsInvalidB : TestFormat_Op<[{ $operand operands }]>, Arguments<(ins I64:$operand)>; // CHECK-NOT: error: -def DirectiveOperandsValid : TestFormat_Op<"operands_valid", [{ +def DirectiveOperandsValid : TestFormat_Op<[{ operands attr-dict }]>; +//===----------------------------------------------------------------------===// +// ref + +// CHECK: error: 'ref' is only valid within a `custom` directive +def DirectiveRefInvalidA : TestFormat_Op<[{ + ref(type($operand)) +}]>, Arguments<(ins I64:$operand)>; + +// CHECK: error: 'ref' of 'type($operand)' is not bound by a prior 'type' directive +def DirectiveRefInvalidB : TestFormat_Op<[{ + custom(ref(type($operand))) +}]>, Arguments<(ins I64:$operand)>; + +// CHECK: error: 'ref' of 'type(operands)' is not bound by a prior 'type' directive +def DirectiveRefInvalidC : TestFormat_Op<[{ + custom(ref(type(operands))) +}]>; + +// CHECK: error: 'ref' of 'type($result)' is not bound by a prior 'type' directive +def DirectiveRefInvalidD : TestFormat_Op<[{ + custom(ref(type($result))) +}]>, Results<(outs I64:$result)>; + +// CHECK: error: 'ref' of 'type(results)' is not bound by a prior 'type' directive +def DirectiveRefInvalidE : TestFormat_Op<[{ + custom(ref(type(results))) +}]>; + +// CHECK: error: 'ref' of 'successors' is not bound by a prior 'successors' 
directive +def DirectiveRefInvalidF : TestFormat_Op<[{ + custom(ref(successors)) +}]>; + +// CHECK: error: 'ref' of 'regions' is not bound by a prior 'regions' directive +def DirectiveRefInvalidG : TestFormat_Op<[{ + custom(ref(regions)) +}]>; + +// CHECK: error: expected '(' before argument list +def DirectiveRefInvalidH : TestFormat_Op<[{ + custom(ref) +}]>; + +// CHECK: error: expected ')' after argument list +def DirectiveRefInvalidI : TestFormat_Op<[{ + operands custom(ref(operands( +}]>; + +// CHECK: error: 'ref' of 'operands' is not bound by a prior 'operands' directive +def DirectiveRefInvalidJ : TestFormat_Op<[{ + custom(ref(operands)) +}]>; + +// CHECK: error: 'ref' of 'attr-dict' is not bound by a prior 'attr-dict' directive +def DirectiveRefInvalidK : TestFormat_Op<[{ + custom(ref(attr-dict)) +}]>; + +// CHECK: error: successor 'successor' must be bound before it is referenced +def DirectiveRefInvalidL : TestFormat_Op<[{ + custom(ref($successor)) +}]> { + let successors = (successor AnySuccessor:$successor); +} + +// CHECK: error: region 'region' must be bound before it is referenced +def DirectiveRefInvalidM : TestFormat_Op<[{ + custom(ref($region)) +}]> { + let regions = (region AnyRegion:$region); +} + +// CHECK: error: attribute 'attr' must be bound before it is referenced +def DirectiveRefInvalidN : TestFormat_Op<[{ + custom(ref($attr)) +}]>, Arguments<(ins I64Attr:$attr)>; + + +// CHECK: error: operand 'operand' must be bound before it is referenced +def DirectiveRefInvalidO : TestFormat_Op<[{ + custom(ref($operand)) +}]>, Arguments<(ins I64:$operand)>; + //===----------------------------------------------------------------------===// // regions // CHECK: error: 'regions' directive creates overlap in format -def DirectiveRegionsInvalidA : TestFormat_Op<"regions_invalid_a", [{ +def DirectiveRegionsInvalidA : TestFormat_Op<[{ regions regions attr-dict }]>; // CHECK: error: 'regions' directive creates overlap in format -def DirectiveRegionsInvalidB : 
TestFormat_Op<"regions_invalid_b", [{ +def DirectiveRegionsInvalidB : TestFormat_Op<[{ $region regions attr-dict }]> { let regions = (region AnyRegion:$region); } // CHECK: error: 'regions' is only valid as a top-level directive -def DirectiveRegionsInvalidC : TestFormat_Op<"regions_invalid_c", [{ +def DirectiveRegionsInvalidC : TestFormat_Op<[{ type(regions) }]>; // CHECK-NOT: error: -def DirectiveRegionsValid : TestFormat_Op<"regions_valid", [{ +def DirectiveRegionsValid : TestFormat_Op<[{ regions attr-dict }]>; //===----------------------------------------------------------------------===// // results -// CHECK: error: 'results' directive can not be used as a top-level directive -def DirectiveResultsInvalidA : TestFormat_Op<"results_invalid_a", [{ +// CHECK: error: 'results' directive can can only be used as a child to a 'type' directive +def DirectiveResultsInvalidA : TestFormat_Op<[{ results }]>; @@ -167,7 +250,7 @@ // successors // CHECK: error: 'successors' is only valid as a top-level directive -def DirectiveSuccessorsInvalidA : TestFormat_Op<"successors_invalid_a", [{ +def DirectiveSuccessorsInvalidA : TestFormat_Op<[{ type(successors) }]>; @@ -175,140 +258,78 @@ // type // CHECK: error: expected '(' before argument list -def DirectiveTypeInvalidA : TestFormat_Op<"type_invalid_a", [{ +def DirectiveTypeInvalidA : TestFormat_Op<[{ type }]>; // CHECK: error: expected directive, literal, variable, or optional group -def DirectiveTypeInvalidB : TestFormat_Op<"type_invalid_b", [{ +def DirectiveTypeInvalidB : TestFormat_Op<[{ type( }]>; // CHECK: error: expected ')' after argument list -def DirectiveTypeInvalidC : TestFormat_Op<"type_invalid_c", [{ +def DirectiveTypeInvalidC : TestFormat_Op<[{ type(operands }]>; // CHECK-NOT: error: -def DirectiveTypeValid : TestFormat_Op<"type_valid", [{ +def DirectiveTypeValid : TestFormat_Op<[{ type(operands) attr-dict }]>; //===----------------------------------------------------------------------===// // functional-type/type 
operands -// CHECK: error: 'type' directive operand expects variable or directive operand -def DirectiveTypeZOperandInvalidA : TestFormat_Op<"type_operand_invalid_a", [{ +// CHECK: error: literals may only be used in a top-level section of the format +def DirectiveTypeZOperandInvalidA : TestFormat_Op<[{ type(`literal`) }]>; // CHECK: error: 'operands' 'type' is already bound -def DirectiveTypeZOperandInvalidB : TestFormat_Op<"type_operand_invalid_b", [{ +def DirectiveTypeZOperandInvalidB : TestFormat_Op<[{ type(operands) type(operands) }]>; // CHECK: error: 'operands' 'type' is already bound -def DirectiveTypeZOperandInvalidC : TestFormat_Op<"type_operand_invalid_c", [{ +def DirectiveTypeZOperandInvalidC : TestFormat_Op<[{ type($operand) type(operands) }]>, Arguments<(ins I64:$operand)>; // CHECK: error: 'type' of 'operand' is already bound -def DirectiveTypeZOperandInvalidD : TestFormat_Op<"type_operand_invalid_d", [{ +def DirectiveTypeZOperandInvalidD : TestFormat_Op<[{ type(operands) type($operand) }]>, Arguments<(ins I64:$operand)>; // CHECK: error: 'type' of 'operand' is already bound -def DirectiveTypeZOperandInvalidE : TestFormat_Op<"type_operand_invalid_e", [{ +def DirectiveTypeZOperandInvalidE : TestFormat_Op<[{ type($operand) type($operand) }]>, Arguments<(ins I64:$operand)>; // CHECK: error: 'results' 'type' is already bound -def DirectiveTypeZOperandInvalidF : TestFormat_Op<"type_operand_invalid_f", [{ +def DirectiveTypeZOperandInvalidF : TestFormat_Op<[{ type(results) type(results) }]>; // CHECK: error: 'results' 'type' is already bound -def DirectiveTypeZOperandInvalidG : TestFormat_Op<"type_operand_invalid_g", [{ +def DirectiveTypeZOperandInvalidG : TestFormat_Op<[{ type($result) type(results) }]>, Results<(outs I64:$result)>; // CHECK: error: 'type' of 'result' is already bound -def DirectiveTypeZOperandInvalidH : TestFormat_Op<"type_operand_invalid_h", [{ +def DirectiveTypeZOperandInvalidH : TestFormat_Op<[{ type(results) type($result) }]>, 
Results<(outs I64:$result)>; // CHECK: error: 'type' of 'result' is already bound -def DirectiveTypeZOperandInvalidI : TestFormat_Op<"type_operand_invalid_i", [{ +def DirectiveTypeZOperandInvalidI : TestFormat_Op<[{ type($result) type($result) }]>, Results<(outs I64:$result)>; -//===----------------------------------------------------------------------===// -// type_ref - -// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidC : TestFormat_Op<"type_ref_operand_invalid_c", [{ - type_ref($operand) type(operands) -}]>, Arguments<(ins I64:$operand)>; -// CHECK: error: 'operands' 'type_ref' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidD : TestFormat_Op<"type_ref_operand_invalid_d", [{ - type_ref(operands) type($operand) -}]>, Arguments<(ins I64:$operand)>; -// CHECK: error: 'type_ref' of 'operand' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidE : TestFormat_Op<"type_ref_operand_invalid_e", [{ - type_ref($operand) type($operand) -}]>, Arguments<(ins I64:$operand)>; -// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidG : TestFormat_Op<"type_ref_operand_invalid_g", [{ - type_ref($result) type(results) -}]>, Results<(outs I64:$result)>; -// CHECK: error: 'results' 'type_ref' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidH : TestFormat_Op<"type_ref_operand_invalid_h", [{ - type_ref(results) type($result) -}]>, Results<(outs I64:$result)>; -// CHECK: error: 'type_ref' of 'result' is not bound by a prior 'type' directive -def DirectiveTypeZZTypeRefOperandInvalidI : TestFormat_Op<"type_ref_operand_invalid_i", [{ - type_ref($result) type($result) -}]>, Results<(outs I64:$result)>; - -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandB : TestFormat_Op<"type_ref_operand_valid_b", [{ - type_ref(operands) attr-dict -}]>; -// 
CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandD : TestFormat_Op<"type_ref_operand_valid_d", [{ - type(operands) type_ref($operand) attr-dict -}]>, Arguments<(ins I64:$operand)>; -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandE : TestFormat_Op<"type_ref_operand_valid_e", [{ - type($operand) type_ref($operand) attr-dict -}]>, Arguments<(ins I64:$operand)>; -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandF : TestFormat_Op<"type_ref_operand_valid_f", [{ - type(results) type_ref(results) attr-dict -}]>; -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandG : TestFormat_Op<"type_ref_operand_valid_g", [{ - type($result) type_ref(results) attr-dict -}]>, Results<(outs I64:$result)>; -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandH : TestFormat_Op<"type_ref_operand_valid_h", [{ - type(results) type_ref($result) attr-dict -}]>, Results<(outs I64:$result)>; -// CHECK-NOT: error -def DirectiveTypeZZTypeRefOperandI : TestFormat_Op<"type_ref_operand_valid_i", [{ - type($result) type_ref($result) attr-dict -}]>, Results<(outs I64:$result)>; - -// CHECK-NOT: error: -def DirectiveTypeZZZOperandValid : TestFormat_Op<"type_operand_valid", [{ - type(operands) type(results) attr-dict -}]>; - //===----------------------------------------------------------------------===// // Literals //===----------------------------------------------------------------------===// // Test all of the valid literals. 
// CHECK: error: expected valid literal -def LiteralInvalidA : TestFormat_Op<"literal_invalid_a", [{ +def LiteralInvalidA : TestFormat_Op<[{ `1` }]>; // CHECK: error: unexpected end of file in literal // CHECK: error: expected directive, literal, variable, or optional group -def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ +def LiteralInvalidB : TestFormat_Op<[{ ` }]>; // CHECK-NOT: error -def LiteralValid : TestFormat_Op<"literal_valid", [{ +def LiteralValid : TestFormat_Op<[{ `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._` attr-dict }]>; @@ -318,60 +339,60 @@ //===----------------------------------------------------------------------===// // CHECK: error: optional groups can only be used as top-level elements -def OptionalInvalidA : TestFormat_Op<"optional_invalid_a", [{ +def OptionalInvalidA : TestFormat_Op<[{ type(($attr^)?) attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: expected directive, literal, variable, or optional group -def OptionalInvalidB : TestFormat_Op<"optional_invalid_b", [{ +def OptionalInvalidB : TestFormat_Op<[{ () attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: optional group specified no anchor element -def OptionalInvalidC : TestFormat_Op<"optional_invalid_c", [{ +def OptionalInvalidC : TestFormat_Op<[{ ($attr)? attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: first parsable element of an operand group must be an attribute, literal, operand, or region -def OptionalInvalidD : TestFormat_Op<"optional_invalid_d", [{ +def OptionalInvalidD : TestFormat_Op<[{ (type($operand) $operand^)? attr-dict }]>, Arguments<(ins Optional:$operand)>; // CHECK: error: only literals, types, and variables can be used within an optional group -def OptionalInvalidE : TestFormat_Op<"optional_invalid_e", [{ +def OptionalInvalidE : TestFormat_Op<[{ (`,` $attr^ type(operands))? 
attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: only one element can be marked as the anchor of an optional group -def OptionalInvalidF : TestFormat_Op<"optional_invalid_f", [{ +def OptionalInvalidF : TestFormat_Op<[{ ($attr^ $attr2^) attr-dict }]>, Arguments<(ins OptionalAttr:$attr, OptionalAttr:$attr2)>; // CHECK: error: only optional attributes can be used to anchor an optional group -def OptionalInvalidG : TestFormat_Op<"optional_invalid_g", [{ +def OptionalInvalidG : TestFormat_Op<[{ ($attr^) attr-dict }]>, Arguments<(ins I64Attr:$attr)>; // CHECK: error: only variable length operands can be used within an optional group -def OptionalInvalidH : TestFormat_Op<"optional_invalid_h", [{ +def OptionalInvalidH : TestFormat_Op<[{ ($arg^) attr-dict }]>, Arguments<(ins I64:$arg)>; // CHECK: error: only literals, types, and variables can be used within an optional group -def OptionalInvalidI : TestFormat_Op<"optional_invalid_i", [{ +def OptionalInvalidI : TestFormat_Op<[{ (functional-type($arg, results)^)? attr-dict }]>, Arguments<(ins Variadic:$arg)>; // CHECK: error: only literals, types, and variables can be used within an optional group -def OptionalInvalidJ : TestFormat_Op<"optional_invalid_j", [{ +def OptionalInvalidJ : TestFormat_Op<[{ (attr-dict) }]>; // CHECK: error: expected '?' after optional group -def OptionalInvalidK : TestFormat_Op<"optional_invalid_k", [{ +def OptionalInvalidK : TestFormat_Op<[{ ($arg^) }]>, Arguments<(ins Variadic:$arg)>; // CHECK: error: only variables and types can be used to anchor an optional group -def OptionalInvalidL : TestFormat_Op<"optional_invalid_l", [{ +def OptionalInvalidL : TestFormat_Op<[{ (custom($arg)^)? }]>, Arguments<(ins I64:$arg)>; // CHECK: error: only variables and types can be used to anchor an optional group -def OptionalInvalidM : TestFormat_Op<"optional_invalid_m", [{ +def OptionalInvalidM : TestFormat_Op<[{ (` `^)? 
}]>, Arguments<(ins)>; // CHECK-NOT: error -def OptionalValidA : TestFormat_Op<"optional_valid_a", [{ +def OptionalValidA : TestFormat_Op<[{ (` ` `` $arg^)? }]>; @@ -380,78 +401,78 @@ //===----------------------------------------------------------------------===// // CHECK: error: expected variable to refer to an argument, region, result, or successor -def VariableInvalidA : TestFormat_Op<"variable_invalid_a", [{ +def VariableInvalidA : TestFormat_Op<[{ $unknown_arg attr-dict }]>; // CHECK: error: attribute 'attr' is already bound -def VariableInvalidB : TestFormat_Op<"variable_invalid_b", [{ +def VariableInvalidB : TestFormat_Op<[{ $attr $attr attr-dict }]>, Arguments<(ins I64Attr:$attr)>; // CHECK: error: operand 'operand' is already bound -def VariableInvalidC : TestFormat_Op<"variable_invalid_c", [{ +def VariableInvalidC : TestFormat_Op<[{ $operand $operand attr-dict }]>, Arguments<(ins I64:$operand)>; // CHECK: error: operand 'operand' is already bound -def VariableInvalidD : TestFormat_Op<"variable_invalid_d", [{ +def VariableInvalidD : TestFormat_Op<[{ operands $operand attr-dict }]>, Arguments<(ins I64:$operand)>; -// CHECK: error: results can not be used at the top level -def VariableInvalidE : TestFormat_Op<"variable_invalid_e", [{ +// CHECK: error: result variables can can only be used as a child to a 'type' directive +def VariableInvalidE : TestFormat_Op<[{ $result attr-dict }]>, Results<(outs I64:$result)>; // CHECK: error: successor 'successor' is already bound -def VariableInvalidF : TestFormat_Op<"variable_invalid_f", [{ +def VariableInvalidF : TestFormat_Op<[{ $successor $successor attr-dict }]> { let successors = (successor AnySuccessor:$successor); } // CHECK: error: successor 'successor' is already bound -def VariableInvalidG : TestFormat_Op<"variable_invalid_g", [{ +def VariableInvalidG : TestFormat_Op<[{ successors $successor attr-dict }]> { let successors = (successor AnySuccessor:$successor); } // CHECK: error: format ambiguity caused by `:` 
literal found after attribute `attr` which does not have a buildable type -def VariableInvalidH : TestFormat_Op<"variable_invalid_h", [{ +def VariableInvalidH : TestFormat_Op<[{ $attr `:` attr-dict }]>, Arguments<(ins ElementsAttr:$attr)>; // CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidI : TestFormat_Op<"variable_invalid_i", [{ +def VariableInvalidI : TestFormat_Op<[{ (`foo` $attr^)? `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; // CHECK: error: format ambiguity caused by `:` literal found after attribute `attr` which does not have a buildable type -def VariableInvalidJ : TestFormat_Op<"variable_invalid_j", [{ +def VariableInvalidJ : TestFormat_Op<[{ $attr ` ` `:` attr-dict }]>, Arguments<(ins ElementsAttr:$attr)>; // CHECK: error: region 'region' is already bound -def VariableInvalidK : TestFormat_Op<"variable_invalid_k", [{ +def VariableInvalidK : TestFormat_Op<[{ $region $region attr-dict }]> { let regions = (region AnyRegion:$region); } // CHECK: error: region 'region' is already bound -def VariableInvalidL : TestFormat_Op<"variable_invalid_l", [{ +def VariableInvalidL : TestFormat_Op<[{ regions $region attr-dict }]> { let regions = (region AnyRegion:$region); } // CHECK: error: regions can only be used at the top level -def VariableInvalidM : TestFormat_Op<"variable_invalid_m", [{ +def VariableInvalidM : TestFormat_Op<[{ type($region) }]> { let regions = (region AnyRegion:$region); } // CHECK: error: region #0, named 'region', not found -def VariableInvalidN : TestFormat_Op<"variable_invalid_n", [{ +def VariableInvalidN : TestFormat_Op<[{ attr-dict }]> { let regions = (region AnyRegion:$region); } // CHECK-NOT: error: -def VariableValidA : TestFormat_Op<"variable_valid_a", [{ +def VariableValidA : TestFormat_Op<[{ $attr `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; -def VariableValidB : TestFormat_Op<"variable_valid_b", [{ +def VariableValidB : 
TestFormat_Op<[{ (`foo` $attr^)? `:` attr-dict }]>, Arguments<(ins OptionalAttr:$attr)>; @@ -461,75 +482,75 @@ // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format -def ZCoverageInvalidA : TestFormat_Op<"variable_invalid_a", [{ +def ZCoverageInvalidA : TestFormat_Op<[{ attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; // CHECK: error: operand #0, named 'operand', not found // CHECK: note: suggest adding a '$operand' directive to the custom assembly format -def ZCoverageInvalidB : TestFormat_Op<"variable_invalid_b", [{ +def ZCoverageInvalidB : TestFormat_Op<[{ type($result) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format -def ZCoverageInvalidC : TestFormat_Op<"variable_invalid_c", [{ +def ZCoverageInvalidC : TestFormat_Op<[{ $operand type($result) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format -def ZCoverageInvalidD : TestFormat_Op<"variable_invalid_d", [{ +def ZCoverageInvalidD : TestFormat_Op<[{ operands attr-dict }]>, Arguments<(ins Variadic:$operand)>; // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom 
assembly format -def ZCoverageInvalidE : TestFormat_Op<"variable_invalid_e", [{ +def ZCoverageInvalidE : TestFormat_Op<[{ attr-dict }]>, Results<(outs Variadic:$result)>; // CHECK: error: successor #0, named 'successor', not found // CHECK: note: suggest adding a '$successor' directive to the custom assembly format -def ZCoverageInvalidF : TestFormat_Op<"variable_invalid_f", [{ +def ZCoverageInvalidF : TestFormat_Op<[{ attr-dict }]> { let successors = (successor AnySuccessor:$successor); } // CHECK: error: type of operand #0, named 'operand', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($operand)' directive to the custom assembly format -def ZCoverageInvalidG : TestFormat_Op<"variable_invalid_g", [{ +def ZCoverageInvalidG : TestFormat_Op<[{ operands attr-dict }]>, Arguments<(ins Optional:$operand)>; // CHECK: error: type of result #0, named 'result', is not buildable and a buildable type cannot be inferred // CHECK: note: suggest adding a type constraint to the operation or adding a 'type($result)' directive to the custom assembly format -def ZCoverageInvalidH : TestFormat_Op<"variable_invalid_h", [{ +def ZCoverageInvalidH : TestFormat_Op<[{ attr-dict }]>, Results<(outs Optional:$result)>; // CHECK-NOT: error -def ZCoverageValidA : TestFormat_Op<"variable_valid_a", [{ +def ZCoverageValidA : TestFormat_Op<[{ $operand type($operand) type($result) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; -def ZCoverageValidB : TestFormat_Op<"variable_valid_b", [{ +def ZCoverageValidB : TestFormat_Op<[{ $operand type(operands) type(results) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; -def ZCoverageValidC : TestFormat_Op<"variable_valid_c", [{ +def ZCoverageValidC : TestFormat_Op<[{ operands functional-type(operands, results) attr-dict }]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs 
AnyMemRef:$result)>; // Check that we can infer type equalities from certain traits. -def ZCoverageValidD : TestFormat_Op<"variable_valid_d", [{ +def ZCoverageValidD : TestFormat_Op<[{ operands type($result) attr-dict }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; -def ZCoverageValidE : TestFormat_Op<"variable_valid_e", [{ +def ZCoverageValidE : TestFormat_Op<[{ $operand type($operand) attr-dict }], [SameOperandsAndResultType]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; -def ZCoverageValidF : TestFormat_Op<"variable_valid_f", [{ +def ZCoverageValidF : TestFormat_Op<[{ operands type($other) attr-dict }], [SameTypeOperands]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>; -def ZCoverageValidG : TestFormat_Op<"variable_valid_g", [{ +def ZCoverageValidG : TestFormat_Op<[{ operands type($other) attr-dict }], [AllTypesMatch<["operand", "other"]>]>, Arguments<(ins AnyMemRef:$operand, AnyMemRef:$other)>; -def ZCoverageValidH : TestFormat_Op<"variable_valid_h", [{ +def ZCoverageValidH : TestFormat_Op<[{ operands type($result) attr-dict }], [AllTypesMatch<["operand", "result"]>]>, Arguments<(ins AnyMemRef:$operand)>, Results<(outs AnyMemRef:$result)>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -291,6 +291,12 @@ // CHECK: test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) test.format_custom_directive_results_with_type_refs : i64 -> (i64) type_refs_capture : i64 -> (i64) +// CHECK: test.format_custom_directive_with_optional_operand_ref %[[I64]] : 1 +test.format_custom_directive_with_optional_operand_ref %i64 : 1 + +// CHECK: test.format_custom_directive_with_optional_operand_ref : 0 +test.format_custom_directive_with_optional_operand_ref : 0 + func @foo() { // CHECK: 
test.format_custom_directive_successors ^bb1, ^bb2 test.format_custom_directive_successors ^bb1, ^bb2 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -58,11 +58,11 @@ CustomDirective, FunctionalTypeDirective, OperandsDirective, + RefDirective, RegionsDirective, ResultsDirective, SuccessorsDirective, TypeDirective, - TypeRefDirective, /// This element is a literal. Literal, @@ -234,10 +234,10 @@ std::unique_ptr inputs, results; }; -/// This class represents the `type` directive. -class TypeDirective : public DirectiveElement { +/// This class represents the `ref` directive. +class RefDirective : public DirectiveElement { public: - TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + RefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} Element *getOperand() const { return operand.get(); } private: @@ -245,11 +245,10 @@ std::unique_ptr operand; }; -/// This class represents the `type_ref` directive. -class TypeRefDirective - : public DirectiveElement { +/// This class represents the `type` directive. +class TypeDirective : public DirectiveElement { public: - TypeRefDirective(std::unique_ptr arg) : operand(std::move(arg)) {} + TypeDirective(std::unique_ptr arg) : operand(std::move(arg)) {} Element *getOperand() const { return operand.get(); } private: @@ -873,19 +872,6 @@ << llvm::formatv( " ::llvm::ArrayRef<::mlir::Type> {0}Types({0}RawTypes);\n", name); - } else if (auto *dir = dyn_cast(element)) { - ArgumentLengthKind lengthKind; - StringRef name = getTypeListName(dir->getOperand(), lengthKind); - // Refer to the previously encountered TypeDirective for name. - // Take a `const ::mlir::SmallVector<::mlir::Type, 1> &` in the declaration - // to properly track the types that will be parsed and pushed later on. 
- if (lengthKind != ArgumentLengthKind::Single) - body << " const ::mlir::SmallVector<::mlir::Type, 1> &" << name - << "TypesRef(" << name << "Types);\n"; - else - body << llvm::formatv( - " ::llvm::ArrayRef<::mlir::Type> {0}RawTypesRef({0}RawTypes);\n", - name); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind ignored; body << " ::llvm::ArrayRef<::mlir::Type> " @@ -897,7 +883,6 @@ /// Generate the parser for a parameter to a custom directive. static void genCustomParameterParser(Element ¶m, OpMethodBody &body) { - body << ", "; if (auto *attr = dyn_cast(¶m)) { body << attr->getVar()->name << "Attr"; } else if (isa(¶m)) { @@ -926,15 +911,9 @@ else body << llvm::formatv("{0}Successor", name); - } else if (auto *dir = dyn_cast(¶m)) { - ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) - body << llvm::formatv("{0}TypesRef", listName); - else if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv("{0}TypeRef", listName); - else - body << formatv("{0}RawTypesRef[0]", listName); + } else if (auto *dir = dyn_cast(¶m)) { + genCustomParameterParser(*dir->getOperand(), body); + } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -967,27 +946,39 @@ "{0}Operand;\n", operand->getVar()->name); } - } else if (auto *dir = dyn_cast(¶m)) { - // Reference to an optional which may or may not have been set. - // Retrieve from vector if not empty. - ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv( - " ::mlir::Type {0}TypeRef = {0}TypesRef.empty() " - "? 
Type() : {0}TypesRef[0];\n", - listName); } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); + } else if (auto *dir = dyn_cast(¶m)) { + Element *input = dir->getOperand(); + if (auto *operand = dyn_cast(input)) { + if (!operand->getVar()->isOptional()) + continue; + body << llvm::formatv( + " {0} {1}Operand = {1}Operands.empty() ? {0}() : " + "{1}Operands[0];\n", + "llvm::Optional<::mlir::OpAsmParser::OperandType>", + operand->getVar()->name); + + } else if (auto *type = dyn_cast(input)) { + ArgumentLengthKind lengthKind; + StringRef listName = getTypeListName(type->getOperand(), lengthKind); + if (lengthKind == ArgumentLengthKind::Optional) { + body << llvm::formatv(" ::mlir::Type {0}Type = {0}Types.empty() ? " + "::mlir::Type() : {0}Types[0];\n", + listName); + } + } } } body << " if (parse" << dir->getName() << "(parser"; - for (Element ¶m : dir->getArguments()) + for (Element ¶m : dir->getArguments()) { + body << ", "; genCustomParameterParser(param, body); + } body << "))\n" << " return ::mlir::failure();\n"; @@ -1008,9 +999,6 @@ body << llvm::formatv(" if ({0}Operand.hasValue())\n" " {0}Operands.push_back(*{0}Operand);\n", var->name); - } else if (isa(¶m)) { - // In the `type_ref` case, do not parse a new Type that needs to be added. - // Just do nothing here. 
} else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1238,15 +1226,6 @@ } else if (isa(element)) { body << llvm::formatv(successorListParserCode, "full"); - } else if (auto *dir = dyn_cast(element)) { - ArgumentLengthKind lengthKind; - StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) - body << llvm::formatv(variadicTypeParserCode, listName); - else if (lengthKind == ArgumentLengthKind::Optional) - body << llvm::formatv(optionalTypeParserCode, listName); - else - body << formatv(typeParserCode, listName); } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1587,54 +1566,51 @@ shouldEmitSpace = false; } +/// Generate the printer for a custom directive parameter. +static void genCustomDirectiveParameterPrinter(Element *element, + OpMethodBody &body) { + if (auto *attr = dyn_cast(element)) { + body << attr->getVar()->name << "Attr()"; + + } else if (isa(element)) { + body << "getOperation()->getAttrDictionary()"; + + } else if (auto *operand = dyn_cast(element)) { + body << operand->getVar()->name << "()"; + + } else if (auto *region = dyn_cast(element)) { + body << region->getVar()->name << "()"; + + } else if (auto *successor = dyn_cast(element)) { + body << successor->getVar()->name << "()"; + + } else if (auto *dir = dyn_cast(element)) { + genCustomDirectiveParameterPrinter(dir->getOperand(), body); + + } else if (auto *dir = dyn_cast(element)) { + auto *typeOperand = dir->getOperand(); + auto *operand = dyn_cast(typeOperand); + auto *var = operand ? operand->getVar() + : cast(typeOperand)->getVar(); + if (var->isVariadic()) + body << var->name << "().getTypes()"; + else if (var->isOptional()) + body << llvm::formatv("({0}() ? 
{0}().getType() : Type())", var->name); + else + body << var->name << "().getType()"; + } else { + llvm_unreachable("unknown custom directive parameter"); + } +} + /// Generate the printer for a custom directive. static void genCustomDirectivePrinter(CustomDirective *customDir, OpMethodBody &body) { body << " print" << customDir->getName() << "(p, *this"; for (Element ¶m : customDir->getArguments()) { body << ", "; - if (auto *attr = dyn_cast(¶m)) { - body << attr->getVar()->name << "Attr()"; - - } else if (isa(¶m)) { - body << "getOperation()->getAttrDictionary()"; - - } else if (auto *operand = dyn_cast(¶m)) { - body << operand->getVar()->name << "()"; - - } else if (auto *region = dyn_cast(¶m)) { - body << region->getVar()->name << "()"; - - } else if (auto *successor = dyn_cast(¶m)) { - body << successor->getVar()->name << "()"; - - } else if (auto *dir = dyn_cast(¶m)) { - auto *typeOperand = dir->getOperand(); - auto *operand = dyn_cast(typeOperand); - auto *var = operand ? operand->getVar() - : cast(typeOperand)->getVar(); - if (var->isVariadic()) - body << var->name << "().getTypes()"; - else if (var->isOptional()) - body << llvm::formatv("({0}() ? {0}().getType() : Type())", var->name); - else - body << var->name << "().getType()"; - } else if (auto *dir = dyn_cast(¶m)) { - auto *typeOperand = dir->getOperand(); - auto *operand = dyn_cast(typeOperand); - auto *var = operand ? operand->getVar() - : cast(typeOperand)->getVar(); - if (var->isVariadic()) - body << var->name << "().getTypes()"; - else if (var->isOptional()) - body << llvm::formatv("({0}() ? 
{0}().getType() : Type())", var->name); - else - body << var->name << "().getType()"; - } else { - llvm_unreachable("unknown custom directive parameter"); - } + genCustomDirectiveParameterPrinter(¶m, body); } - body << ");\n"; } @@ -1886,9 +1862,6 @@ } else if (auto *dir = dyn_cast(element)) { body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; - } else if (auto *dir = dyn_cast(element)) { - body << " p << "; - genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { body << " p.printFunctionalType("; genTypeOperandPrinter(dir->getInputs(), body) << ", "; @@ -1951,11 +1924,11 @@ kw_custom, kw_functional_type, kw_operands, + kw_ref, kw_regions, kw_results, kw_successors, kw_type, - kw_type_ref, keyword_end, // String valued tokens. @@ -2156,11 +2129,11 @@ .Case("custom", Token::kw_custom) .Case("functional-type", Token::kw_functional_type) .Case("operands", Token::kw_operands) + .Case("ref", Token::kw_ref) .Case("regions", Token::kw_regions) .Case("results", Token::kw_results) .Case("successors", Token::kw_successors) .Case("type", Token::kw_type) - .Case("type_ref", Token::kw_type_ref) .Default(Token::identifier); return Token(kind, str); } @@ -2191,6 +2164,19 @@ LogicalResult parse(); private: + /// The current context of the parser when parsing an element. + enum ParserContext { + /// The element is being parsed in a "top-level" context, i.e. at the top of + /// the format or in an optional group. + TopLevelContext, + /// The element is being parsed as a custom directive child. + CustomDirectiveContext, + /// The element is being parsed as a type directive child. + TypeDirectiveContext, + /// The element is being parsed as a reference directive child. + RefDirectiveContext + }; + /// This struct represents a type resolution instance. It includes a specific /// type as well as an optional transformer to apply to that type in order to /// properly resolve the type of a variable. 
@@ -2249,14 +2235,15 @@ /// Parse a specific element. LogicalResult parseElement(std::unique_ptr &element, - bool isTopLevel); + ParserContext context); LogicalResult parseVariable(std::unique_ptr &element, - bool isTopLevel); + ParserContext context); LogicalResult parseDirective(std::unique_ptr &element, - bool isTopLevel); - LogicalResult parseLiteral(std::unique_ptr &element); + ParserContext context); + LogicalResult parseLiteral(std::unique_ptr &element, + ParserContext context); LogicalResult parseOptional(std::unique_ptr &element, - bool isTopLevel); + ParserContext context); LogicalResult parseOptionalChildElement( std::vector> &childElements, Optional &anchorIdx); @@ -2265,26 +2252,29 @@ /// Parse the various different directives. LogicalResult parseAttrDictDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel, + llvm::SMLoc loc, ParserContext context, bool withKeyword); LogicalResult parseCustomDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel); + llvm::SMLoc loc, ParserContext context); LogicalResult parseCustomDirectiveParameter( std::vector> ¶meters); LogicalResult parseFunctionalTypeDirective(std::unique_ptr &element, - Token tok, bool isTopLevel); + Token tok, ParserContext context); LogicalResult parseOperandsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel); + llvm::SMLoc loc, ParserContext context); + LogicalResult parseReferenceDirective(std::unique_ptr &element, + llvm::SMLoc loc, ParserContext context); LogicalResult parseRegionsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel); + llvm::SMLoc loc, ParserContext context); LogicalResult parseResultsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel); + llvm::SMLoc loc, ParserContext context); LogicalResult parseSuccessorsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel); + llvm::SMLoc loc, + ParserContext context); LogicalResult parseTypeDirective(std::unique_ptr &element, 
Token tok, - bool isTopLevel, bool isTypeRef = false); + ParserContext context); LogicalResult parseTypeDirectiveOperand(std::unique_ptr &element, - bool isTypeRef = false); + bool isRefChild = false); //===--------------------------------------------------------------------===// // Lexer Utilities @@ -2340,7 +2330,7 @@ // Parse each of the format elements into the main format. while (curToken.getKind() != Token::eof) { std::unique_ptr element; - if (failed(parseElement(element, /*isTopLevel=*/true))) + if (failed(parseElement(element, TopLevelContext))) return ::mlir::failure(); fmt.elements.push_back(std::move(element)); } @@ -2634,25 +2624,25 @@ } LogicalResult FormatParser::parseElement(std::unique_ptr &element, - bool isTopLevel) { + ParserContext context) { // Directives. if (curToken.isKeyword()) - return parseDirective(element, isTopLevel); + return parseDirective(element, context); // Literals. if (curToken.getKind() == Token::literal) - return parseLiteral(element); + return parseLiteral(element, context); // Optionals. if (curToken.getKind() == Token::l_paren) - return parseOptional(element, isTopLevel); + return parseOptional(element, context); // Variables. if (curToken.getKind() == Token::variable) - return parseVariable(element, isTopLevel); + return parseVariable(element, context); return emitError(curToken.getLoc(), "expected directive, literal, variable, or optional group"); } LogicalResult FormatParser::parseVariable(std::unique_ptr &element, - bool isTopLevel) { + ParserContext context) { Token varTok = curToken; consumeToken(); @@ -2663,42 +2653,67 @@ // op. 
/// Attributes if (const NamedAttribute *attr = findArg(op.getAttributes(), name)) { - if (isTopLevel && !seenAttrs.insert(attr)) + if (context == TypeDirectiveContext) + return emitError( + loc, "attributes cannot be used as children to a `type` directive"); + if (context == RefDirectiveContext) { + if (!seenAttrs.count(attr)) + return emitError(loc, "attribute '" + name + + "' must be bound before it is referenced"); + } else if (!seenAttrs.insert(attr)) { return emitError(loc, "attribute '" + name + "' is already bound"); + } + element = std::make_unique(attr); return ::mlir::success(); } /// Operands if (const NamedTypeConstraint *operand = findArg(op.getOperands(), name)) { - if (isTopLevel) { + if (context == TopLevelContext || context == CustomDirectiveContext) { if (fmt.allOperands || !seenOperands.insert(operand).second) return emitError(loc, "operand '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && !seenOperands.count(operand)) { + return emitError(loc, "operand '" + name + + "' must be bound before it is referenced"); } element = std::make_unique(operand); return ::mlir::success(); } /// Regions if (const NamedRegion *region = findArg(op.getRegions(), name)) { - if (!isTopLevel) + if (context == TopLevelContext || context == CustomDirectiveContext) { + if (hasAllRegions || !seenRegions.insert(region).second) + return emitError(loc, "region '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && !seenRegions.count(region)) { + return emitError(loc, "region '" + name + + "' must be bound before it is referenced"); + } else { return emitError(loc, "regions can only be used at the top level"); - if (hasAllRegions || !seenRegions.insert(region).second) - return emitError(loc, "region '" + name + "' is already bound"); + } element = std::make_unique(region); return ::mlir::success(); } /// Results. 
if (const auto *result = findArg(op.getResults(), name)) { - if (isTopLevel) - return emitError(loc, "results can not be used at the top level"); + if (context != TypeDirectiveContext) + return emitError(loc, "result variables can can only be used as a child " + "to a 'type' directive"); element = std::make_unique(result); return ::mlir::success(); } /// Successors. if (const auto *successor = findArg(op.getSuccessors(), name)) { - if (!isTopLevel) + if (context == TopLevelContext || context == CustomDirectiveContext) { + if (hasAllSuccessors || !seenSuccessors.insert(successor).second) + return emitError(loc, "successor '" + name + "' is already bound"); + } else if (context == RefDirectiveContext && + !seenSuccessors.count(successor)) { + return emitError(loc, "successor '" + name + + "' must be bound before it is referenced"); + } else { return emitError(loc, "successors can only be used at the top level"); - if (hasAllSuccessors || !seenSuccessors.insert(successor).second) - return emitError(loc, "successor '" + name + "' is already bound"); + } + element = std::make_unique(successor); return ::mlir::success(); } @@ -2707,41 +2722,47 @@ } LogicalResult FormatParser::parseDirective(std::unique_ptr &element, - bool isTopLevel) { + ParserContext context) { Token dirTok = curToken; consumeToken(); switch (dirTok.getKind()) { case Token::kw_attr_dict: - return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel, + return parseAttrDictDirective(element, dirTok.getLoc(), context, /*withKeyword=*/false); case Token::kw_attr_dict_w_keyword: - return parseAttrDictDirective(element, dirTok.getLoc(), isTopLevel, + return parseAttrDictDirective(element, dirTok.getLoc(), context, /*withKeyword=*/true); case Token::kw_custom: - return parseCustomDirective(element, dirTok.getLoc(), isTopLevel); + return parseCustomDirective(element, dirTok.getLoc(), context); case Token::kw_functional_type: - return parseFunctionalTypeDirective(element, dirTok, isTopLevel); + return 
parseFunctionalTypeDirective(element, dirTok, context); case Token::kw_operands: - return parseOperandsDirective(element, dirTok.getLoc(), isTopLevel); + return parseOperandsDirective(element, dirTok.getLoc(), context); case Token::kw_regions: - return parseRegionsDirective(element, dirTok.getLoc(), isTopLevel); + return parseRegionsDirective(element, dirTok.getLoc(), context); case Token::kw_results: - return parseResultsDirective(element, dirTok.getLoc(), isTopLevel); + return parseResultsDirective(element, dirTok.getLoc(), context); case Token::kw_successors: - return parseSuccessorsDirective(element, dirTok.getLoc(), isTopLevel); - case Token::kw_type_ref: - return parseTypeDirective(element, dirTok, isTopLevel, /*isTypeRef=*/true); + return parseSuccessorsDirective(element, dirTok.getLoc(), context); + case Token::kw_ref: + return parseReferenceDirective(element, dirTok.getLoc(), context); case Token::kw_type: - return parseTypeDirective(element, dirTok, isTopLevel); + return parseTypeDirective(element, dirTok, context); default: llvm_unreachable("unknown directive token"); } } -LogicalResult FormatParser::parseLiteral(std::unique_ptr &element) { +LogicalResult FormatParser::parseLiteral(std::unique_ptr &element, + ParserContext context) { Token literalTok = curToken; + if (context != TopLevelContext) { + return emitError( + literalTok.getLoc(), + "literals may only be used in a top-level section of the format"); + } consumeToken(); StringRef value = literalTok.getSpelling().drop_front().drop_back(); @@ -2766,9 +2787,9 @@ } LogicalResult FormatParser::parseOptional(std::unique_ptr &element, - bool isTopLevel) { + ParserContext context) { llvm::SMLoc curLoc = curToken.getLoc(); - if (!isTopLevel) + if (context != TopLevelContext) return emitError(curLoc, "optional groups can only be used as top-level " "elements"); consumeToken(); @@ -2812,7 +2833,7 @@ Optional &anchorIdx) { llvm::SMLoc childLoc = curToken.getLoc(); childElements.push_back({}); - if 
(failed(parseElement(childElements.back(), /*isTopLevel=*/true))) + if (failed(parseElement(childElements.back(), TopLevelContext))) return ::mlir::failure(); // Check to see if this element is the anchor of the optional group. @@ -2843,7 +2864,7 @@ }) // Only optional-like(i.e. variadic) operands can be within an optional // group. - .Case([&](OperandVariable *ele) { + .Case([&](OperandVariable *ele) { if (!ele->getVar()->isVariableLength()) return emitError(childLoc, "only variable length operands can be " "used within an optional group"); @@ -2851,22 +2872,22 @@ }) // Only optional-like(i.e. variadic) results can be within an optional // group. - .Case([&](ResultVariable *ele) { + .Case([&](ResultVariable *ele) { if (!ele->getVar()->isVariableLength()) return emitError(childLoc, "only variable length results can be " "used within an optional group"); return ::mlir::success(); }) - .Case([&](RegionVariable *) { + .Case([&](RegionVariable *) { // TODO: When ODS has proper support for marking "optional" regions, add // a check here. return ::mlir::success(); }) - .Case([&](TypeDirective *ele) { + .Case([&](TypeDirective *ele) { return verifyOptionalChildElement(ele->getOperand(), childLoc, /*isAnchor=*/false); }) - .Case([&](FunctionalTypeDirective *ele) { + .Case([&](FunctionalTypeDirective *ele) { if (failed(verifyOptionalChildElement(ele->getInputs(), childLoc, /*isAnchor=*/false))) return failure(); @@ -2876,13 +2897,12 @@ // Literals, whitespace, and custom directives may be used, but they can't // anchor the group. 
.Case( - [&](Element *) { - if (isAnchor) - return emitError(childLoc, "only variables and types can be used " - "to anchor an optional group"); - return ::mlir::success(); - }) + FunctionalTypeDirective, OptionalElement>([&](Element *) { + if (isAnchor) + return emitError(childLoc, "only variables and types can be used " + "to anchor an optional group"); + return ::mlir::success(); + }) .Default([&](Element *) { return emitError(childLoc, "only literals, types, and variables can be " "used within an optional group"); @@ -2891,23 +2911,34 @@ LogicalResult FormatParser::parseAttrDictDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel, + llvm::SMLoc loc, ParserContext context, bool withKeyword) { - if (!isTopLevel) + if (context == TypeDirectiveContext) return emitError(loc, "'attr-dict' directive can only be used as a " "top-level directive"); - if (hasAttrDict) - return emitError(loc, "'attr-dict' directive has already been seen"); - hasAttrDict = true; + if (context == RefDirectiveContext) { + if (!hasAttrDict) + return emitError(loc, "'ref' of 'attr-dict' is not bound by a prior " + "'attr-dict' directive"); + + // Otherwise, this is a top-level context. + } else { + if (hasAttrDict) + return emitError(loc, "'attr-dict' directive has already been seen"); + hasAttrDict = true; + } + element = std::make_unique(withKeyword); return ::mlir::success(); } LogicalResult FormatParser::parseCustomDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel) { + llvm::SMLoc loc, ParserContext context) { llvm::SMLoc curLoc = curToken.getLoc(); + if (context != TopLevelContext) + return emitError(loc, "'custom' is only valid as a top-level directive"); // Parse the custom directive name. if (failed( @@ -2940,13 +2971,6 @@ // After parsing all of the elements, ensure that all type directives refer // only to variables. 
for (auto &ele : elements) { - if (auto *typeEle = dyn_cast(ele.get())) { - if (!isa(typeEle->getOperand())) { - return emitError(curLoc, - "type_ref directives within a custom directive " - "may only refer to variables"); - } - } if (auto *typeEle = dyn_cast(ele.get())) { if (!isa(typeEle->getOperand())) { return emitError(curLoc, "type directives within a custom directive " @@ -2964,13 +2988,13 @@ std::vector> ¶meters) { llvm::SMLoc childLoc = curToken.getLoc(); parameters.push_back({}); - if (failed(parseElement(parameters.back(), /*isTopLevel=*/true))) + if (failed(parseElement(parameters.back(), CustomDirectiveContext))) return ::mlir::failure(); // Verify that the element can be placed within a custom directive. - if (!isa(parameters.back().get())) { + if (!isa( + parameters.back().get())) { return emitError(childLoc, "only variables and types may be used as " "parameters to a custom directive"); } @@ -2979,9 +3003,9 @@ LogicalResult FormatParser::parseFunctionalTypeDirective(std::unique_ptr &element, - Token tok, bool isTopLevel) { + Token tok, ParserContext context) { llvm::SMLoc loc = tok.getLoc(); - if (!isTopLevel) + if (context != TopLevelContext) return emitError( loc, "'functional-type' is only valid as a top-level directive"); @@ -3000,8 +3024,13 @@ LogicalResult FormatParser::parseOperandsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel) { - if (isTopLevel) { + llvm::SMLoc loc, ParserContext context) { + if (context == RefDirectiveContext) { + if (!fmt.allOperands) + return emitError(loc, "'ref' of 'operands' is not bound by a prior " + "'operands' directive"); + + } else if (context == TopLevelContext || context == CustomDirectiveContext) { if (fmt.allOperands || !seenOperands.empty()) return emitError(loc, "'operands' directive creates overlap in format"); fmt.allOperands = true; @@ -3010,65 +3039,96 @@ return ::mlir::success(); } +LogicalResult +FormatParser::parseReferenceDirective(std::unique_ptr &element, + llvm::SMLoc 
loc, ParserContext context) { + if (context != CustomDirectiveContext) + return emitError(loc, "'ref' is only valid within a `custom` directive"); + + std::unique_ptr operand; + if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || + failed(parseElement(operand, RefDirectiveContext)) || + failed(parseToken(Token::r_paren, "expected ')' after argument list"))) + return ::mlir::failure(); + + element = std::make_unique(std::move(operand)); + return ::mlir::success(); +} + LogicalResult FormatParser::parseRegionsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel) { - if (!isTopLevel) + llvm::SMLoc loc, ParserContext context) { + if (context == TypeDirectiveContext) return emitError(loc, "'regions' is only valid as a top-level directive"); - if (hasAllRegions || !seenRegions.empty()) - return emitError(loc, "'regions' directive creates overlap in format"); - hasAllRegions = true; + if (context == RefDirectiveContext) { + if (!hasAllRegions) + return emitError(loc, "'ref' of 'regions' is not bound by a prior " + "'regions' directive"); + + // Otherwise, this is a TopLevel directive. 
+ } else { + if (hasAllRegions || !seenRegions.empty()) + return emitError(loc, "'regions' directive creates overlap in format"); + hasAllRegions = true; + } element = std::make_unique(); return ::mlir::success(); } LogicalResult FormatParser::parseResultsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel) { - if (isTopLevel) - return emitError(loc, "'results' directive can not be used as a " - "top-level directive"); + llvm::SMLoc loc, ParserContext context) { + if (context != TypeDirectiveContext) + return emitError(loc, "'results' directive can can only be used as a child " + "to a 'type' directive"); element = std::make_unique(); return ::mlir::success(); } LogicalResult FormatParser::parseSuccessorsDirective(std::unique_ptr &element, - llvm::SMLoc loc, bool isTopLevel) { - if (!isTopLevel) + llvm::SMLoc loc, ParserContext context) { + if (context == TypeDirectiveContext) return emitError(loc, "'successors' is only valid as a top-level directive"); - if (hasAllSuccessors || !seenSuccessors.empty()) - return emitError(loc, "'successors' directive creates overlap in format"); - hasAllSuccessors = true; + if (context == RefDirectiveContext) { + if (!hasAllSuccessors) + return emitError(loc, "'ref' of 'successors' is not bound by a prior " + "'successors' directive"); + + // Otherwise, this is a TopLevel directive. 
+ } else { + if (hasAllSuccessors || !seenSuccessors.empty()) + return emitError(loc, "'successors' directive creates overlap in format"); + hasAllSuccessors = true; + } element = std::make_unique(); return ::mlir::success(); } LogicalResult FormatParser::parseTypeDirective(std::unique_ptr &element, Token tok, - bool isTopLevel, bool isTypeRef) { + ParserContext context) { llvm::SMLoc loc = tok.getLoc(); - if (!isTopLevel) - return emitError(loc, "'type' is only valid as a top-level directive"); + if (context == TypeDirectiveContext) + return emitError(loc, "'type' cannot be used as a child of another `type`"); + bool isRefChild = context == RefDirectiveContext; std::unique_ptr operand; if (failed(parseToken(Token::l_paren, "expected '(' before argument list")) || - failed(parseTypeDirectiveOperand(operand, isTypeRef)) || + failed(parseTypeDirectiveOperand(operand, isRefChild)) || failed(parseToken(Token::r_paren, "expected ')' after argument list"))) return ::mlir::failure(); - if (isTypeRef) - element = std::make_unique(std::move(operand)); - else - element = std::make_unique(std::move(operand)); + + element = std::make_unique(std::move(operand)); return ::mlir::success(); } LogicalResult FormatParser::parseTypeDirectiveOperand(std::unique_ptr &element, - bool isTypeRef) { + bool isRefChild) { llvm::SMLoc loc = curToken.getLoc(); - if (failed(parseElement(element, /*isTopLevel=*/false))) + if (failed(parseElement(element, TypeDirectiveContext))) return ::mlir::failure(); if (isa(element.get())) return emitError( @@ -3076,36 +3136,35 @@ if (auto *var = dyn_cast(element.get())) { unsigned opIdx = var->getVar() - op.operand_begin(); - if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) + if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.test(opIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + "' is already bound"); - if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) - return emitError(loc, 
"'type_ref' of '" + var->getVar()->name + - "' is not bound by a prior 'type' directive"); + if (isRefChild && !(fmt.allOperandTypes || seenOperandTypes.test(opIdx))) + return emitError(loc, "'ref' of 'type($" + var->getVar()->name + + ")' is not bound by a prior 'type' directive"); seenOperandTypes.set(opIdx); } else if (auto *var = dyn_cast(element.get())) { unsigned resIdx = var->getVar() - op.result_begin(); - if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.test(resIdx))) + if (!isRefChild && (fmt.allResultTypes || seenResultTypes.test(resIdx))) return emitError(loc, "'type' of '" + var->getVar()->name + "' is already bound"); - if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) - return emitError(loc, "'type_ref' of '" + var->getVar()->name + - "' is not bound by a prior 'type' directive"); + if (isRefChild && !(fmt.allResultTypes || seenResultTypes.test(resIdx))) + return emitError(loc, "'ref' of 'type($" + var->getVar()->name + + ")' is not bound by a prior 'type' directive"); seenResultTypes.set(resIdx); } else if (isa(&*element)) { - if (!isTypeRef && (fmt.allOperandTypes || seenOperandTypes.any())) + if (!isRefChild && (fmt.allOperandTypes || seenOperandTypes.any())) return emitError(loc, "'operands' 'type' is already bound"); - if (isTypeRef && !(fmt.allOperandTypes || seenOperandTypes.all())) - return emitError( - loc, - "'operands' 'type_ref' is not bound by a prior 'type' directive"); + if (isRefChild && !fmt.allOperandTypes) + return emitError(loc, "'ref' of 'type(operands)' is not bound by a prior " + "'type' directive"); fmt.allOperandTypes = true; } else if (isa(&*element)) { - if (!isTypeRef && (fmt.allResultTypes || seenResultTypes.any())) + if (!isRefChild && (fmt.allResultTypes || seenResultTypes.any())) return emitError(loc, "'results' 'type' is already bound"); - if (isTypeRef && !(fmt.allResultTypes || seenResultTypes.all())) - return emitError( - loc, "'results' 'type_ref' is not bound by a prior 'type' 
directive"); + if (isRefChild && !fmt.allResultTypes) + return emitError(loc, "'ref' of 'type(results)' is not bound by a prior " + "'type' directive"); fmt.allResultTypes = true; } else { return emitError(loc, "invalid argument to 'type' directive");