diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -895,6 +895,19 @@ The custom parser is considered to have failed if it returns failure or if any bound parameters have failure values afterwards. +A string of C++ code can be used as a `custom` directive argument. When +generating the custom parser and printer calls, the string is pasted as a +function argument. For example, `parseBar` and `printBar` can be re-used with +a constant integer: + +```tablegen +let parameters = (ins "int":$bar); +let assemblyFormat = [{ custom<Bar>($bar, "1") }]; +``` + +The string is pasted verbatim, except that the placeholders `$_builder` and +`$_ctxt` are substituted. String literals can parameterize custom directives. + ### Verification If the `genVerifyDecl` field is set, additional verification methods are diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -768,9 +768,9 @@ identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)` would result in calls to `parseMyDirective` and `printMyDirective` within the parser and printer respectively. `Params` may be any combination of variables -(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`. -The type directives must refer to a variable, but that variable need not also be -a parameter to the custom directive. +(i.e. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and +strings of C++ code. The type directives must refer to a variable, but that +variable need not also be a parameter to the custom directive. The arguments to the `parse` method are firstly a reference to the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters @@ -837,7 +837,16 @@ - VariadicOfVariadic: `TypeRangeRange` * `attr-dict` Directive: `DictionaryAttr` -When a variable is optional, the provided value may be null. 
+When a variable is optional, the provided value may be null. When a variable is +referenced in a custom directive parameter using `ref`, it is passed in by +value. The custom printer receives referenced variables exactly like bound +variables, and the custom parser receives them the same way the printer does: +as input values rather than as output parameters. + +A custom directive can take a string of C++ code as a parameter. The code is +pasted verbatim in the calls to the custom parser and printer, with the +placeholders `$_builder` and `$_ctxt` substituted. String literals can be used +to parameterize custom directives. #### Optional Groups @@ -1462,7 +1471,7 @@ if (2u == (2u & val)) { strs.push_back("Bit1"); } if (4u == (4u & val)) { strs.push_back("Bit2"); } if (8u == (8u & val)) { strs.push_back("Bit3"); } - + return llvm::join(strs, "|"); } diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -555,3 +555,19 @@ let mnemonic = "type_k"; let assemblyFormat = "$a"; } + +// TYPE-LABEL: ::mlir::Type TestNType::parse +// TYPE: parseFoo( +// TYPE-NEXT: _result_a, +// TYPE-NEXT: 1); + +// TYPE-LABEL: void TestNType::print +// TYPE: printFoo( +// TYPE-NEXT: getA(), +// TYPE-NEXT: 1); + +def TypeL : TestType<"TestN"> { + let parameters = (ins "int":$a); + let mnemonic = "type_l"; + let assemblyFormat = [{ custom($a, "1") }]; +} diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -403,6 +403,13 @@ ($arg^):(`test`) }]>, Arguments<(ins Variadic:$arg)>; +//===----------------------------------------------------------------------===// +// Strings +//===----------------------------------------------------------------------===// + +// CHECK: error: strings may only be used as 'custom' directive arguments +def StringInvalidA : 
TestFormat_Op<[{ "foo" }]>; + //===----------------------------------------------------------------------===// // Variables //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -135,6 +135,13 @@ (` ` `` $arg^)? attr-dict }]>, Arguments<(ins Optional:$arg)>; +//===----------------------------------------------------------------------===// +// Strings +//===----------------------------------------------------------------------===// + +// CHECK-NOT: error +def StringValidA : TestFormat_Op<[{ custom("foo") attr-dict }]>; + //===----------------------------------------------------------------------===// // Variables //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format.td @@ -0,0 +1,34 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op traits = []> + : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Directives +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// custom + +// CHECK-LABEL: CustomStringLiteralA::parse +// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type()) +// CHECK-LABEL: CustomStringLiteralA::print +// CHECK: printFoo({{.*}}, ::mlir::Builder(getContext()).getI1Type()) +def CustomStringLiteralA : TestFormat_Op<[{ + custom("$_builder.getI1Type()") attr-dict +}]>; + +// CHECK-LABEL: CustomStringLiteralB::parse +// CHECK: parseFoo({{.*}}, 
IndexType::get(parser.getContext())) +// CHECK-LABEL: CustomStringLiteralB::print +// CHECK: printFoo({{.*}}, IndexType::get(getContext())) +def CustomStringLiteralB : TestFormat_Op<[{ + custom("IndexType::get($_ctxt)") attr-dict +}]>; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -629,14 +629,12 @@ os.indent(); for (FormatElement *arg : el->getArguments()) { os << ",\n"; - FormatElement *param; - if (auto *ref = dyn_cast(arg)) { - os << "*"; - param = ref->getArg(); - } else { - param = arg; - } - os << "_result_" << cast(param)->getName(); + if (auto *param = dyn_cast(arg)) + os << "_result_" << param->getName(); + else if (auto *ref = dyn_cast(arg)) + os << "*_result_" << cast(ref->getArg())->getName(); + else + os << tgfmt(cast(arg)->getValue(), &ctx); } os.unindent() << ");\n"; os << "if (::mlir::failed(odsCustomResult)) return {};\n"; @@ -845,11 +843,15 @@ os << tgfmt("print$0($_printer", &ctx, el->getName()); os.indent(); for (FormatElement *arg : el->getArguments()) { - FormatElement *param = arg; - if (auto *ref = dyn_cast(arg)) - param = ref->getArg(); - os << ",\n" - << cast(param)->getParam().getAccessorName() << "()"; + os << ",\n"; + if (auto *param = dyn_cast(arg)) { + os << param->getParam().getAccessorName() << "()"; + } else if (auto *ref = dyn_cast(arg)) { + os << cast(ref->getArg())->getParam().getAccessorName() + << "()"; + } else { + os << tgfmt(cast(arg)->getValue(), &ctx); + } } os.unindent() << ");\n"; } diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -78,6 +78,7 @@ identifier, literal, variable, + string, }; FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} @@ -130,10 +131,11 @@ /// Return the next character in the 
stream. int getNextChar(); - /// Lex an identifier, literal, or variable. + /// Lex an identifier, literal, variable, or string. FormatToken lexIdentifier(const char *tokStart); FormatToken lexLiteral(const char *tokStart); FormatToken lexVariable(const char *tokStart); + FormatToken lexString(const char *tokStart); /// Create a token with the current pointer and a start pointer. FormatToken formToken(FormatToken::Kind kind, const char *tokStart) { @@ -163,7 +165,7 @@ virtual ~FormatElement(); // The top-level kinds of format elements. - enum Kind { Literal, Variable, Whitespace, Directive, Optional }; + enum Kind { Literal, String, Variable, Whitespace, Directive, Optional }; /// Support LLVM-style RTTI. static bool classof(const FormatElement *el) { return true; } @@ -212,6 +214,20 @@ StringRef spelling; }; +/// This class represents a raw string that can contain arbitrary C++ code. +class StringElement : public FormatElementBase { +public: + /// Create a string element with the given contents. + explicit StringElement(StringRef value) : value(value) {} + + /// Get the value of the string element. + StringRef getValue() const { return value; } + +private: + /// The contents of the string. + StringRef value; +}; + /// This class represents a variable element. A variable refers to some part of /// the object being parsed, e.g. an attribute or operand on an operation or a /// parameter on an attribute. @@ -447,6 +463,8 @@ FailureOr parseElement(Context ctx); /// Parse a literal. FailureOr parseLiteral(Context ctx); + /// Parse a string. + FailureOr parseString(Context ctx); /// Parse a variable. FailureOr parseVariable(Context ctx); /// Parse a directive. 
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -129,6 +129,8 @@ return lexLiteral(tokStart); case '$': return lexVariable(tokStart); + case '"': + return lexString(tokStart); } } @@ -153,6 +155,17 @@ return formToken(FormatToken::variable, tokStart); } +FormatToken FormatLexer::lexString(const char *tokStart) { + // Lex until an unescaped quote. A backslash escapes only the next char. + bool escape = false; + while (const char curChar = *curPtr++) { + if (!escape && curChar == '"') + return formToken(FormatToken::string, tokStart); + escape = !escape && curChar == '\\'; + } + return emitError(curPtr - 1, "unexpected end of file in string"); +} + FormatToken FormatLexer::lexIdentifier(const char *tokStart) { // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') @@ -212,6 +225,8 @@ FailureOr FormatParser::parseElement(Context ctx) { if (curToken.is(FormatToken::literal)) return parseLiteral(ctx); + if (curToken.is(FormatToken::string)) + return parseString(ctx); if (curToken.is(FormatToken::variable)) return parseVariable(ctx); if (curToken.isKeyword()) @@ -253,6 +268,18 @@ return create(value); } +FailureOr FormatParser::parseString(Context ctx) { + FormatToken tok = curToken; + SMLoc loc = tok.getLoc(); + consumeToken(); + + if (ctx != CustomDirectiveContext) { + return emitError( + loc, "strings may only be used as 'custom' directive arguments"); + } + return create(tok.getSpelling().drop_front().drop_back()); +} + FailureOr FormatParser::parseVariable(Context ctx) { FormatToken tok = curToken; SMLoc loc = tok.getLoc(); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -916,6 +916,13 @@ body << llvm::formatv("{0}Type", listName); else body << formatv("{0}RawTypes[0]", 
listName); + + } else if (auto *string = dyn_cast(param)) { + FmtContext ctx; + ctx.withBuilder("parser.getBuilder()"); + ctx.addSubst("_ctxt", "parser.getContext()"); + body << tgfmt(string->getValue(), &ctx); + } else { llvm_unreachable("unknown custom directive parameter"); } @@ -1715,6 +1722,14 @@ body << llvm::formatv("({0}() ? {0}().getType() : Type())", name); else body << name << "().getType()"; + + } else if (auto *string = dyn_cast(element)) { + // The generated printer has no `parser`; use the printed op's context. + FmtContext ctx; + ctx.withBuilder("::mlir::Builder(getContext())"); + ctx.addSubst("_ctxt", "getContext()"); + body << tgfmt(string->getValue(), &ctx); + } else { llvm_unreachable("unknown custom directive parameter"); } @@ -2826,8 +2841,9 @@ LogicalResult OpFormatParser::verifyCustomDirectiveArguments( SMLoc loc, ArrayRef arguments) { for (FormatElement *argument : arguments) { - if (!isa(argument)) { + if (!isa(argument)) { // TODO: FormatElement should have location info attached. return emitError(loc, "only variables and types may be used as " "parameters to a custom directive");