diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -895,6 +895,19 @@ The custom parser is considered to have failed if it returns failure or if any bound parameters have failure values afterwards. +A string of C++ code can be used as a `custom` directive argument. When +generating the custom parser and printer calls, the string is pasted as a +function argument. For example, `parseBar` and `printBar` can be re-used with +a constant integer: + +```tablegen +let parameters = (ins "int":$bar); +let assemblyFormat = [{ custom<Bar>($bar, "1") }]; +``` + +The string is pasted verbatim but with substitutions for `$_builder` and +`$_ctxt`. String literals can be used to parameterize custom directives. + ### Verification If the `genVerifyDecl` field is set, additional verification methods are diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -768,9 +768,9 @@ identifier used as a suffix to these two calls, i.e., `custom(...)` would result in calls to `parseMyDirective` and `printMyDirective` within the parser and printer respectively. `Params` may be any combination of variables -(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`. -The type directives must refer to a variable, but that variable need not also be -a parameter to the custom directive. +(e.g. Attribute, Operand, Successor, etc.), type directives, `attr-dict`, and +strings of C++ code. The type directives must refer to a variable, but that +variable need not also be a parameter to the custom directive. The arguments to the `parse` method are firstly a reference to the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters @@ -837,7 +837,16 @@ - VariadicOfVariadic: `TypeRangeRange` * `attr-dict` Directive: `DictionaryAttr` -When a variable is optional, the provided value may be null.
+When a variable is optional, the provided value may be null. When a variable is +referenced in a custom directive parameter using `ref`, it is passed in by +value. Referenced variables to `print<UserDirective>` are passed in the same +way as bound variables, but referenced variables to `parse<UserDirective>` are +passed in the same way as they are to the printer. + +A custom directive can take a string of C++ code as a parameter. The code is +pasted verbatim into the calls to the custom parser and printer, with the +substitutions `$_builder` and `$_ctxt` applied. String literals can be used to +parameterize custom directives. #### Optional Groups @@ -1462,7 +1471,7 @@ if (2u == (2u & val)) { strs.push_back("Bit1"); } if (4u == (4u & val)) { strs.push_back("Bit2"); } if (8u == (8u & val)) { strs.push_back("Bit3"); } - + return llvm::join(strs, "|"); } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -45,8 +45,9 @@ let results = (outs AnyTensor:$result); let assemblyFormat = [{ - custom($sizes, $static_sizes) attr-dict - `:` type($result) + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + attr-dict `:` type($result) }]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -1023,11 +1023,14 @@ let assemblyFormat = [{ $source `to` `offset` `` `:` - custom($offsets, $static_offsets) + custom($offsets, $static_offsets, + "ShapedType::kDynamicStrideOrOffset") `` `,` `sizes` `` `:` - custom($sizes, $static_sizes) `` `,` `strides` - `` `:` - custom($strides, $static_strides) + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + `` `,` `strides` `` `:` + custom($strides, $static_strides, + "ShapedType::kDynamicStrideOrOffset") attr-dict `:`
type($source) `to` type($result) }]; @@ -1586,9 +1589,12 @@ let assemblyFormat = [{ $source `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) + custom($offsets, $static_offsets, + "ShapedType::kDynamicStrideOrOffset") + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + custom($strides, $static_strides, + "ShapedType::kDynamicStrideOrOffset") attr-dict `:` type($source) `to` type($result) }]; diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -219,11 +219,11 @@ To disambiguate, the inference helpers `inferCanonicalRankReducedResultType` only drop the first unit dimensions, in order: e.g. 1x6x1 rank-reduced to 2-D will infer the 6x1 2-D shape, but not 1x6. - + Verification however has access to result type and does not need to infer. - The verifier calls `isRankReducedType(getSource(), getResult())` to + The verifier calls `isRankReducedType(getSource(), getResult())` to determine whether the result type is rank-reduced from the source type. - This computes a so-called rank-reduction mask, consisting of dropped unit + This computes a so-called rank-reduction mask, consisting of dropped unit dims, to map the rank-reduced type to the source type by dropping ones: e.g. 
1x6 is a rank-reduced version of 1x6x1 by mask {2} 6x1 is a rank-reduced version of 1x6x1 by mask {0} @@ -254,9 +254,12 @@ let assemblyFormat = [{ $source `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) + custom($offsets, $static_offsets, + "ShapedType::kDynamicStrideOrOffset") + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + custom($strides, $static_strides, + "ShapedType::kDynamicStrideOrOffset") attr-dict `:` type($source) `to` type($result) }]; @@ -298,12 +301,12 @@ /// tensor type to the result tensor type by dropping unit dims. llvm::Optional> computeRankReductionMask() { - return ::mlir::computeRankReductionMask(getSourceType().getShape(), + return ::mlir::computeRankReductionMask(getSourceType().getShape(), getType().getShape()); }; /// An extract_slice result type can be inferred, when it is not - /// rank-reduced, from the source type and the static representation of + /// rank-reduced, from the source type and the static representation of /// offsets, sizes and strides. Special sentinels encode the dynamic case. static RankedTensorType inferResultType( ShapedType sourceShapedTensorType, @@ -580,9 +583,12 @@ let assemblyFormat = [{ $source `into` $dest `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) + custom($offsets, $static_offsets, + "ShapedType::kDynamicStrideOrOffset") + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + custom($strides, $static_strides, + "ShapedType::kDynamicStrideOrOffset") attr-dict `:` type($source) `into` type($dest) }]; @@ -608,7 +614,7 @@ RankedTensorType getType() { return getResult().getType().cast(); } - + /// The `dest` type is the same as the result type. RankedTensorType getDestType() { return getType(); @@ -962,8 +968,10 @@ let assemblyFormat = [{ $source (`nofold` $nofold^)? 
- `low` `` custom($low, $static_low) - `high` `` custom($high, $static_high) + `low` `` custom($low, $static_low, + "ShapedType::kDynamicSize") + `high` `` custom($high, $static_high, + "ShapedType::kDynamicSize") $region attr-dict `:` type($source) `to` type($result) }]; @@ -1069,15 +1077,15 @@ // HasParent<"ParallelCombiningOpInterface"> ]> { let summary = [{ - Specify the tensor slice update of a single thread of a parent + Specify the tensor slice update of a single thread of a parent ParallelCombiningOpInterface op. }]; let description = [{ - The `parallel_insert_slice` yields a subset tensor value to its parent + The `parallel_insert_slice` yields a subset tensor value to its parent ParallelCombiningOpInterface. These subset tensor values are aggregated to - in some unspecified order into a full tensor value returned by the parent - parallel iterating op. - The `parallel_insert_slice` is one such op allowed in the + in some unspecified order into a full tensor value returned by the parent + parallel iterating op. + The `parallel_insert_slice` is one such op allowed in the ParallelCombiningOpInterface op. Conflicting writes result in undefined semantics, in that the indices written @@ -1118,12 +1126,12 @@ into a memref.subview op. A parallel_insert_slice operation may additionally specify insertion into a - tensor of higher rank than the source tensor, along dimensions that are + tensor of higher rank than the source tensor, along dimensions that are statically known to be of size 1. This rank-altering behavior is not required by the op semantics: this flexibility allows to progressively drop unit dimensions while lowering between different flavors of ops on that operate on tensors. - The rank-altering behavior of tensor.parallel_insert_slice matches the + The rank-altering behavior of tensor.parallel_insert_slice matches the rank-reducing behavior of tensor.insert_slice and tensor.extract_slice. 
Verification in the rank-reduced case: @@ -1144,9 +1152,12 @@ ); let assemblyFormat = [{ $source `into` $dest `` - custom($offsets, $static_offsets) - custom($sizes, $static_sizes) - custom($strides, $static_strides) + custom($offsets, $static_offsets, + "ShapedType::kDynamicStrideOrOffset") + custom($sizes, $static_sizes, + "ShapedType::kDynamicSize") + custom($strides, $static_strides, + "ShapedType::kDynamicStrideOrOffset") attr-dict `:` type($source) `into` type($dest) }]; @@ -1194,7 +1205,7 @@ "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)> ]; - + let hasCanonicalizer = 1; let hasFolder = 1; let hasVerifier = 1; diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -84,73 +84,40 @@ /// Printer hook for custom directive in assemblyFormat. /// -/// custom($values, $integers) +/// custom($values, $integers) /// /// where `values` is of ODS type `Variadic` and `integers` is of ODS -/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with -/// either (1) the static integer value in `integers` if the value is -/// ShapedType::kDynamicStrideOrOffset or (2) the next value otherwise. This -/// allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. `[%arg0, 7, 42, %arg42]`. -void printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &printer, - Operation *op, - OperandRange values, - ArrayAttr integers); - -/// Printer hook for custom directive in assemblyFormat. -/// -/// custom($values, $integers) -/// -/// where `values` is of ODS type `Variadic` and `integers` is of ODS -/// type `I64ArrayAttr`. for use in in assemblyFormat. Prints a list with -/// either (1) the static integer value in `integers` if the value is -/// ShapedType::kDynamicSize or (2) the next value otherwise. 
This -/// allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. `[%arg0, 7, 42, %arg42]`. -void printOperandsOrIntegersSizesList(OpAsmPrinter &printer, Operation *op, - OperandRange values, ArrayAttr integers); - -/// Pasrer hook for custom directive in assemblyFormat. -/// -/// custom($values, $integers) -/// -/// where `values` is of ODS type `Variadic` and `integers` is of ODS -/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with -/// either (1) static integer values or (2) SSA values. Fill `integers` with -/// the integer ArrayAttr, where ShapedType::kDynamicStrideOrOffset encodes the -/// position of SSA values. Add the parsed SSA values to `values` in-order. -// -/// E.g. after parsing "[%arg0, 7, 42, %arg42]": -/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". -ParseResult parseOperandsOrIntegersOffsetsOrStridesList( - OpAsmParser &parser, - SmallVectorImpl &values, - ArrayAttr &integers); +/// type `I64ArrayAttr`. Prints a list with either (1) the static integer value +/// in `integers` if it is not `dynVal` or (2) the next SSA value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. +void printDynamicIndexList(OpAsmPrinter &printer, Operation *op, + OperandRange values, ArrayAttr integers, + int64_t dynVal); -/// Pasrer hook for custom directive in assemblyFormat. +/// Parser hook for custom directive in assemblyFormat. /// -/// custom($values, $integers) +/// custom($values, $integers) /// /// where `values` is of ODS type `Variadic` and `integers` is of ODS -/// type `I64ArrayAttr`. for use in in assemblyFormat. Parse a mixed list with -/// either (1) static integer values or (2) SSA values. Fill `integers` with -/// the integer ArrayAttr, where ShapedType::kDynamicSize encodes the -/// position of SSA values. Add the parsed SSA values to `values` in-order. +/// type `I64ArrayAttr`.
Parse a mixed list with either (1) static integer +/// values or (2) SSA values. Fill `integers` with the integer ArrayAttr, where +/// `dynVal` encodes the position of SSA values. Add the parsed SSA values +/// to `values` in-order. -// +/// /// E.g. after parsing "[%arg0, 7, 42, %arg42]": -/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". +/// 1. `integers` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" +/// 2. `values` is filled with "[%arg0, %arg42]". -ParseResult parseOperandsOrIntegersSizesList( - OpAsmParser &parser, - SmallVectorImpl &values, - ArrayAttr &integers); +ParseResult +parseDynamicIndexList(OpAsmParser &parser, + SmallVectorImpl &values, + ArrayAttr &integers, int64_t dynVal); -/// Verify that a the `values` has as many elements as the number of entries in +/// Verify that `values` has as many elements as the number of entries in /// `attr` for which `isDynamic` evaluates to true. LogicalResult verifyListOfOperandsOrIntegers( Operation *op, StringRef name, unsigned expectedNumElements, ArrayAttr attr, - ValueRange values, llvm::function_ref isDynamic); + ValueRange values, function_ref isDynamic); } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -987,7 +987,8 @@ auto pdlOperationType = pdl::OperationType::get(parser.getContext()); if (parser.parseOperand(target) || parser.resolveOperand(target, pdlOperationType, result.operands) || - parseOperandsOrIntegersSizesList(parser, dynamicSizes, staticSizes) || + parseDynamicIndexList(parser, dynamicSizes, staticSizes, + ShapedType::kDynamicSize) || parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) || parser.parseOptionalAttrDict(result.attributes)) return ParseResult::failure(); @@ -1001,8 +1002,8 @@ void TileOp::print(OpAsmPrinter &p) { p << ' ' << getTarget(); - printOperandsOrIntegersSizesList(p, getOperation(), getDynamicSizes(), - getStaticSizes()); +
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(), + ShapedType::kDynamicSize); p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()}); } diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -70,45 +70,29 @@ return success(); } -template -static void printOperandsOrIntegersListImpl(OpAsmPrinter &p, ValueRange values, - ArrayAttr arrayAttr) { - p << '['; - if (arrayAttr.empty()) { - p << "]"; +void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op, + OperandRange values, ArrayAttr integers, + int64_t dynVal) { + printer << '['; + if (integers.empty()) { + printer << "]"; return; } unsigned idx = 0; - llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + llvm::interleaveComma(integers, printer, [&](Attribute a) { int64_t val = a.cast().getInt(); if (val == dynVal) - p << values[idx++]; + printer << values[idx++]; else - p << val; + printer << val; }); - p << ']'; + printer << ']'; } -void mlir::printOperandsOrIntegersOffsetsOrStridesList(OpAsmPrinter &p, - Operation *op, - OperandRange values, - ArrayAttr integers) { - return printOperandsOrIntegersListImpl( - p, values, integers); -} - -void mlir::printOperandsOrIntegersSizesList(OpAsmPrinter &p, Operation *op, - OperandRange values, - ArrayAttr integers) { - return printOperandsOrIntegersListImpl(p, values, - integers); -} - -template -static ParseResult parseOperandsOrIntegersImpl( +ParseResult mlir::parseDynamicIndexList( OpAsmParser &parser, SmallVectorImpl &values, - ArrayAttr &integers) { + ArrayAttr &integers, int64_t dynVal) { if (failed(parser.parseLSquare())) return failure(); // 0-D. 
@@ -142,22 +126,6 @@ return success(); } -ParseResult mlir::parseOperandsOrIntegersOffsetsOrStridesList( - OpAsmParser &parser, - SmallVectorImpl &values, - ArrayAttr &integers) { - return parseOperandsOrIntegersImpl( - parser, values, integers); -} - -ParseResult mlir::parseOperandsOrIntegersSizesList( - OpAsmParser &parser, - SmallVectorImpl &values, - ArrayAttr &integers) { - return parseOperandsOrIntegersImpl(parser, values, - integers); -} - bool mlir::detail::sameOffsetsSizesAndStrides( OffsetSizeAndStrideOpInterface a, OffsetSizeAndStrideOpInterface b, llvm::function_ref cmp) { diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -555,3 +555,19 @@ let mnemonic = "type_k"; let assemblyFormat = "$a"; } + +// TYPE-LABEL: ::mlir::Type TestNType::parse +// TYPE: parseFoo( +// TYPE-NEXT: _result_a, +// TYPE-NEXT: 1); + +// TYPE-LABEL: void TestNType::print +// TYPE: printFoo( +// TYPE-NEXT: getA(), +// TYPE-NEXT: 1); + +def TypeL : TestType<"TestN"> { + let parameters = (ins "int":$a); + let mnemonic = "type_l"; + let assemblyFormat = [{ custom($a, "1") }]; +} diff --git a/mlir/test/mlir-tblgen/op-format-invalid.td b/mlir/test/mlir-tblgen/op-format-invalid.td --- a/mlir/test/mlir-tblgen/op-format-invalid.td +++ b/mlir/test/mlir-tblgen/op-format-invalid.td @@ -403,6 +403,13 @@ ($arg^):(`test`) }]>, Arguments<(ins Variadic:$arg)>; +//===----------------------------------------------------------------------===// +// Strings +//===----------------------------------------------------------------------===// + +// CHECK: error: strings may only be used as 'custom' directive arguments +def StringInvalidA : TestFormat_Op<[{ "foo" }]>; + //===----------------------------------------------------------------------===// // Variables //===----------------------------------------------------------------------===// diff --git 
a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -135,6 +135,13 @@ (` ` `` $arg^)? attr-dict }]>, Arguments<(ins Optional:$arg)>; +//===----------------------------------------------------------------------===// +// Strings +//===----------------------------------------------------------------------===// + +// CHECK-NOT: error +def StringValidA : TestFormat_Op<[{ custom("foo") attr-dict }]>; + //===----------------------------------------------------------------------===// // Variables //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-format.td b/mlir/test/mlir-tblgen/op-format.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-format.td @@ -0,0 +1,42 @@ +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" + +def TestDialect : Dialect { + let name = "test"; +} +class TestFormat_Op traits = []> + : Op { + let assemblyFormat = fmt; +} + +//===----------------------------------------------------------------------===// +// Directives +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// custom + +// CHECK-LABEL: CustomStringLiteralA::parse +// CHECK: parseFoo({{.*}}, parser.getBuilder().getI1Type()) +// CHECK-LABEL: CustomStringLiteralA::print +// CHECK: printFoo({{.*}}, ::mlir::Builder(getContext()).getI1Type()) +def CustomStringLiteralA : TestFormat_Op<[{ + custom("$_builder.getI1Type()") attr-dict +}]>; + +// CHECK-LABEL: CustomStringLiteralB::parse +// CHECK: parseFoo({{.*}}, IndexType::get(parser.getContext())) +// CHECK-LABEL: CustomStringLiteralB::print +// CHECK: printFoo({{.*}}, IndexType::get(getContext())) +def CustomStringLiteralB : TestFormat_Op<[{ +
custom("IndexType::get($_ctxt)") attr-dict +}]>; + +// CHECK-LABEL: CustomStringLiteralC::parse +// CHECK: parseFoo({{.*}}, parser.getBuilder().getStringAttr("foo")) +// CHECK-LABEL: CustomStringLiteralC::print +// CHECK: printFoo({{.*}}, ::mlir::Builder(getContext()).getStringAttr("foo")) +def CustomStringLiteralC : TestFormat_Op<[{ + custom("$_builder.getStringAttr(\"foo\")") attr-dict +}]>; diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -629,14 +629,12 @@ os.indent(); for (FormatElement *arg : el->getArguments()) { os << ",\n"; - FormatElement *param; - if (auto *ref = dyn_cast(arg)) { - os << "*"; - param = ref->getArg(); - } else { - param = arg; - } - os << "_result_" << cast(param)->getName(); + if (auto *param = dyn_cast(arg)) + os << "_result_" << param->getName(); + else if (auto *ref = dyn_cast(arg)) + os << "*_result_" << cast(ref->getArg())->getName(); + else + os << tgfmt(cast(arg)->getValue(), &ctx); } os.unindent() << ");\n"; os << "if (::mlir::failed(odsCustomResult)) return {};\n"; @@ -845,11 +843,15 @@ os << tgfmt("print$0($_printer", &ctx, el->getName()); os.indent(); for (FormatElement *arg : el->getArguments()) { - FormatElement *param = arg; - if (auto *ref = dyn_cast(arg)) - param = ref->getArg(); - os << ",\n" - << cast(param)->getParam().getAccessorName() << "()"; + os << ",\n"; + if (auto *param = dyn_cast(arg)) { + os << param->getParam().getAccessorName() << "()"; + } else if (auto *ref = dyn_cast(arg)) { + os << cast(ref->getArg())->getParam().getAccessorName() + << "()"; + } else { + os << tgfmt(cast(arg)->getValue(), &ctx); + } } os.unindent() << ");\n"; } diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h --- a/mlir/tools/mlir-tblgen/FormatGen.h +++ b/mlir/tools/mlir-tblgen/FormatGen.h @@ -78,6 +78,7 @@ identifier, literal, variable, +
string, }; FormatToken(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} @@ -130,10 +131,11 @@ /// Return the next character in the stream. int getNextChar(); - /// Lex an identifier, literal, or variable. + /// Lex an identifier, literal, variable, or string. FormatToken lexIdentifier(const char *tokStart); FormatToken lexLiteral(const char *tokStart); FormatToken lexVariable(const char *tokStart); + FormatToken lexString(const char *tokStart); /// Create a token with the current pointer and a start pointer. FormatToken formToken(FormatToken::Kind kind, const char *tokStart) { @@ -163,7 +165,7 @@ virtual ~FormatElement(); // The top-level kinds of format elements. - enum Kind { Literal, Variable, Whitespace, Directive, Optional }; + enum Kind { Literal, String, Variable, Whitespace, Directive, Optional }; /// Support LLVM-style RTTI. static bool classof(const FormatElement *el) { return true; } @@ -212,6 +214,20 @@ StringRef spelling; }; +/// This class represents a raw string that can contain arbitrary C++ code. +class StringElement : public FormatElementBase { +public: + /// Create a string element with the given contents. + explicit StringElement(std::string value) : value(std::move(value)) {} + + /// Get the value of the string element. + StringRef getValue() const { return value; } + +private: + /// The contents of the string. + std::string value; +}; + /// This class represents a variable element. A variable refers to some part of /// the object being parsed, e.g. an attribute or operand on an operation or a /// parameter on an attribute. @@ -447,6 +463,8 @@ FailureOr parseElement(Context ctx); /// Parse a literal. FailureOr parseLiteral(Context ctx); + /// Parse a string. + FailureOr parseString(Context ctx); /// Parse a variable. FailureOr parseVariable(Context ctx); /// Parse a directive.
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -129,6 +129,8 @@ return lexLiteral(tokStart); case '$': return lexVariable(tokStart); + case '"': + return lexString(tokStart); } } @@ -153,6 +155,17 @@ return formToken(FormatToken::variable, tokStart); } +FormatToken FormatLexer::lexString(const char *tokStart) { + // Lex until another quote, respecting escapes. + bool escape = false; + while (const char curChar = *curPtr++) { + if (!escape && curChar == '"') + return formToken(FormatToken::string, tokStart); + escape = !escape && curChar == '\\'; + } + return emitError(curPtr - 1, "unexpected end of file in string"); +} + FormatToken FormatLexer::lexIdentifier(const char *tokStart) { // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') @@ -212,6 +225,8 @@ FailureOr FormatParser::parseElement(Context ctx) { if (curToken.is(FormatToken::literal)) return parseLiteral(ctx); + if (curToken.is(FormatToken::string)) + return parseString(ctx); if (curToken.is(FormatToken::variable)) return parseVariable(ctx); if (curToken.isKeyword()) @@ -253,6 +268,28 @@ return create(value); } +FailureOr FormatParser::parseString(Context ctx) { + FormatToken tok = curToken; + SMLoc loc = tok.getLoc(); + consumeToken(); + + if (ctx != CustomDirectiveContext) { + return emitError( + loc, "strings may only be used as 'custom' directive arguments"); + } + // Unescape the string: a backslash makes the next character literal.
+ std::string value; + StringRef contents = tok.getSpelling().drop_front().drop_back(); + value.reserve(contents.size()); + bool escape = false; + for (char c : contents) { + escape = !escape && c == '\\'; + if (!escape) + value.push_back(c); + } + return create(std::move(value)); +} + FailureOr FormatParser::parseVariable(Context ctx) { FormatToken tok = curToken; SMLoc loc = tok.getLoc(); diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -916,6 +916,13 @@ body << llvm::formatv("{0}Type", listName); else body << formatv("{0}RawTypes[0]", listName); + + } else if (auto *string = dyn_cast(param)) { + FmtContext ctx; + ctx.withBuilder("parser.getBuilder()"); + ctx.addSubst("_ctxt", "parser.getContext()"); + body << tgfmt(string->getValue(), &ctx); + } else { llvm_unreachable("unknown custom directive parameter"); } @@ -1715,6 +1722,13 @@ body << llvm::formatv("({0}() ? {0}().getType() : Type())", name); else body << name << "().getType()"; + + } else if (auto *string = dyn_cast(element)) { + FmtContext ctx; + ctx.withBuilder("::mlir::Builder(getContext())"); + ctx.addSubst("_ctxt", "getContext()"); + body << tgfmt(string->getValue(), &ctx); + } else { llvm_unreachable("unknown custom directive parameter"); } @@ -2826,8 +2840,9 @@ LogicalResult OpFormatParser::verifyCustomDirectiveArguments( SMLoc loc, ArrayRef arguments) { for (FormatElement *argument : arguments) { - if (!isa(argument)) { + if (!isa(argument)) { // TODO: FormatElement should have location info attached. return emitError(loc, "only variables and types may be used as " "parameters to a custom directive");