diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td
--- a/mlir/test/mlir-tblgen/op-format-invalid.td
+++ b/mlir/test/mlir-tblgen/op-format-invalid.td
@@ -403,6 +403,13 @@
   ($arg^):(`test`)
 }]>, Arguments<(ins Variadic:$arg)>;
 
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK: error: strings may only be used as 'custom' directive arguments
+def StringInvalidA : TestFormat_Op<[{ "foo" }]>;
+
 //===----------------------------------------------------------------------===//
 // Variables
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td
--- a/mlir/test/mlir-tblgen/op-format-spec.td
+++ b/mlir/test/mlir-tblgen/op-format-spec.td
@@ -135,6 +135,13 @@
   (` ` `` $arg^)? attr-dict
 }]>, Arguments<(ins Optional:$arg)>;
 
+//===----------------------------------------------------------------------===//
+// Strings
+//===----------------------------------------------------------------------===//
+
+// CHECK-NOT: error
+def StringValidA : TestFormat_Op<[{ custom("foo") attr-dict }]>;
+
 //===----------------------------------------------------------------------===//
 // Variables
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-format.td
@@ -0,0 +1,34 @@
+// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def TestDialect : Dialect {
+  let name = "test";
+}
+class TestFormat_Op traits = []>
+    : Op {
+  let assemblyFormat = fmt;
+}
+
+//===----------------------------------------------------------------------===//
+// Directives
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// custom
+
+// CHECK-LABEL: CustomStringLiteralA::parse
+// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type())
+// CHECK-LABEL: CustomStringLiteralA::print
+// CHECK: printFoo({{.*}}, ::mlir::Builder(getContext()).getI1Type())
+def CustomStringLiteralA : TestFormat_Op<[{
+  custom("$_builder.getI1Type()") attr-dict
+}]>;
+
+// CHECK-LABEL: CustomStringLiteralB::parse
+// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext()))
+// CHECK-LABEL: CustomStringLiteralB::print
+// CHECK: printFoo({{.*}}, IndexType::get(getContext()))
+def CustomStringLiteralB : TestFormat_Op<[{
+  custom("IndexType::get($_ctxt)") attr-dict
+}]>;
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -78,6 +78,7 @@
     identifier,
     literal,
     variable,
+    string,
   };
 
   FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
@@ -130,10 +131,11 @@
   /// Return the next character in the stream.
   int getNextChar();
 
-  /// Lex an identifier, literal, or variable.
+  /// Lex an identifier, literal, variable, or string.
   FormatToken lexIdentifier(const char *tokStart);
   FormatToken lexLiteral(const char *tokStart);
   FormatToken lexVariable(const char *tokStart);
+  FormatToken lexString(const char *tokStart);
 
   /// Create a token with the current pointer and a start pointer.
   FormatToken formToken(FormatToken::Kind kind, const char *tokStart) {
@@ -163,7 +165,7 @@
   virtual ~FormatElement();
 
   // The top-level kinds of format elements.
-  enum Kind { Literal, Variable, Whitespace, Directive, Optional };
+  enum Kind { Literal, String, Variable, Whitespace, Directive, Optional };
 
   /// Support LLVM-style RTTI.
   static bool classof(const FormatElement *el) { return true; }
@@ -212,6 +214,20 @@
   StringRef spelling;
 };
 
+/// This class represents a raw string that can contain arbitrary C++ code.
+class StringElement : public FormatElementBase {
+public:
+  /// Create a string element with the given contents.
+  explicit StringElement(StringRef value) : value(value) {}
+
+  /// Get the value of the string element.
+  StringRef getValue() const { return value; }
+
+private:
+  /// The contents of the string.
+  StringRef value;
+};
+
 /// This class represents a variable element. A variable refers to some part of
 /// the object being parsed, e.g. an attribute or operand on an operation or a
 /// parameter on an attribute.
@@ -447,6 +463,8 @@
   FailureOr parseElement(Context ctx);
   /// Parse a literal.
   FailureOr parseLiteral(Context ctx);
+  /// Parse a string.
+  FailureOr parseString(Context ctx);
   /// Parse a variable.
   FailureOr parseVariable(Context ctx);
   /// Parse a directive.
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -129,6 +129,8 @@
     return lexLiteral(tokStart);
   case '$':
     return lexVariable(tokStart);
+  case '"':
+    return lexString(tokStart);
   }
 }
 
@@ -153,6 +155,17 @@
   return formToken(FormatToken::variable, tokStart);
 }
 
+FormatToken FormatLexer::lexString(const char *tokStart) {
+  // Lex until an unescaped quote; a backslash escapes only the next character.
+  bool escape = false;
+  while (const char curChar = *curPtr++) {
+    if (!escape && curChar == '"')
+      return formToken(FormatToken::string, tokStart);
+    escape = !escape && curChar == '\\';
+  }
+  return emitError(curPtr - 1, "unexpected end of file in string");
+}
+
 FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
   // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
   while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-')
@@ -212,6 +225,8 @@
 FailureOr FormatParser::parseElement(Context ctx) {
   if (curToken.is(FormatToken::literal))
     return parseLiteral(ctx);
+  if (curToken.is(FormatToken::string))
+    return parseString(ctx);
   if (curToken.is(FormatToken::variable))
     return parseVariable(ctx);
   if (curToken.isKeyword())
@@ -253,6 +268,18 @@
   return create(value);
 }
 
+FailureOr FormatParser::parseString(Context ctx) {
+  FormatToken tok = curToken;
+  SMLoc loc = tok.getLoc();
+  consumeToken();
+
+  if (ctx != CustomDirectiveContext) {
+    return emitError(
+        loc, "strings may only be used as 'custom' directive arguments");
+  }
+  return create(tok.getSpelling().drop_front().drop_back());
+}
+
 FailureOr FormatParser::parseVariable(Context ctx) {
   FormatToken tok = curToken;
   SMLoc loc = tok.getLoc();
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -916,6 +916,13 @@
       body << llvm::formatv("{0}Type", listName);
     else
       body << formatv("{0}RawTypes[0]", listName);
+
+  } else if (auto *string = dyn_cast(param)) {
+    FmtContext ctx;
+    ctx.withBuilder("parser.getBuilder()");
+    ctx.addSubst("_ctxt", "parser.getContext()");
+    body << tgfmt(string->getValue(), &ctx);
+
   } else {
     llvm_unreachable("unknown custom directive parameter");
   }
@@ -1715,6 +1722,13 @@
       body << llvm::formatv("({0}() ? {0}().getType() : Type())", name);
     else
       body << name << "().getType()";
+
+  } else if (auto *string = dyn_cast(element)) {
+    FmtContext ctx;
+    ctx.withBuilder("::mlir::Builder(getContext())");
+    ctx.addSubst("_ctxt", "getContext()");
+    body << tgfmt(string->getValue(), &ctx);
+
   } else {
     llvm_unreachable("unknown custom directive parameter");
   }
@@ -2826,8 +2840,9 @@
 LogicalResult OpFormatParser::verifyCustomDirectiveArguments(
     SMLoc loc, ArrayRef arguments) {
   for (FormatElement *argument : arguments) {
-    if (!isa(argument)) {
+    if (!isa(argument)) {
       // TODO: FormatElement should have location info attached.
       return emitError(loc, "only variables and types may be used as "
                             "parameters to a custom directive");