diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -746,12 +746,12 @@ identifier used as a suffix to these two calls, i.e., `custom<MyDirective>(...)` would result in calls to `parseMyDirective` and `printMyDirective` wihtin the parser and printer respectively. `Params` may be any combination of variables -(i.e. Attribute, Operand, Successor, etc.) and type directives. The type -directives must refer to a variable, but that variable need not also be a -parameter to the custom directive. +(i.e. Attribute, Operand, Successor, etc.), type directives, and `attr-dict`. +The type directives must refer to a variable, but that variable need not also +be a parameter to the custom directive. -The arguments to the `parse` method is firstly a reference to the -`OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters +The arguments to the `parse` method are firstly a reference to +the `OpAsmParser`(`OpAsmParser &`), and secondly a set of output parameters corresponding to the parameters specified in the format. The mapping of declarative parameter to `parse` method argument is detailed below: @@ -776,12 +776,14 @@ - Single: `Type` - Optional: `Type` - Variadic: `const SmallVectorImpl &` +* `attr-dict` Directive: `NamedAttrList &` When a variable is optional, the value should only be specified if the variable is present. Otherwise, the value should remain `None` or null. -The arguments to the `print` method is firstly a reference to the -`OpAsmPrinter`(`OpAsmPrinter &`), and secondly a set of output parameters +The arguments to the `print` method are firstly a reference to +the `OpAsmPrinter`(`OpAsmPrinter &`), second the op (e.g. `FooOp op` which +can be `Operation *op` alternatively), and finally a set of output parameters corresponding to the parameters specified in the format. 
The mapping of declarative parameter to `print` method argument is detailed below: @@ -806,6 +808,7 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` +* `attr-dict` Directive: `const MutableDictionaryAttr&` When a variable is optional, the provided value may be null. diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -276,8 +276,8 @@ return success(); } -static void printAwaitResultType(OpAsmPrinter &p, Type operandType, - Type resultType) { +static void printAwaitResultType(OpAsmPrinter &p, Operation *op, + Type operandType, Type resultType) { p << operandType; } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -831,7 +831,8 @@ OpAsmParser::Delimiter::OptionalSquare); } -static void printAsyncDependencies(OpAsmPrinter &printer, Type asyncTokenType, +static void printAsyncDependencies(OpAsmPrinter &printer, Operation *op, + Type asyncTokenType, OperandRange asyncDependencies) { if (asyncTokenType) printer << "async "; diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -385,19 +385,24 @@ return success(); } +static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser, + NamedAttrList &attrs) { + return parser.parseOptionalAttrDict(attrs); +} + //===----------------------------------------------------------------------===// // Printing -static void printCustomDirectiveOperands(OpAsmPrinter &printer, Value operand, - Value optOperand, +static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, + Value operand, Value optOperand, OperandRange varOperands) { printer << operand; if (optOperand) printer << ", " << optOperand; 
printer << " -> (" << varOperands << ")"; } -static void printCustomDirectiveResults(OpAsmPrinter &printer, Type operandType, - Type optOperandType, +static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, + Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " : " << operandType; if (optOperandType) @@ -405,23 +410,23 @@ printer << " -> (" << varOperandTypes << ")"; } static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, - Type operandType, + Operation *op, Type operandType, Type optOperandType, TypeRange varOperandTypes) { printer << " type_refs_capture "; - printCustomDirectiveResults(printer, operandType, optOperandType, + printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } -static void -printCustomDirectiveOperandsAndTypes(OpAsmPrinter &printer, Value operand, - Value optOperand, OperandRange varOperands, - Type operandType, Type optOperandType, - TypeRange varOperandTypes) { - printCustomDirectiveOperands(printer, operand, optOperand, varOperands); - printCustomDirectiveResults(printer, operandType, optOperandType, +static void printCustomDirectiveOperandsAndTypes( + OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, + OperandRange varOperands, Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); + printCustomDirectiveResults(printer, op, operandType, optOperandType, varOperandTypes); } -static void printCustomDirectiveRegions(OpAsmPrinter &printer, Region &region, +static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, + Region &region, MutableArrayRef varRegions) { printer.printRegion(region); if (!varRegions.empty()) { @@ -430,14 +435,14 @@ printer.printRegion(region); } } -static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, +static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, Block 
*successor, SuccessorRange varSuccessors) { printer << successor; if (!varSuccessors.empty()) printer << ", " << varSuccessors.front(); } -static void printCustomDirectiveAttributes(OpAsmPrinter &printer, +static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, Attribute attribute, Attribute optAttribute) { printer << attribute; @@ -445,6 +450,10 @@ printer << ", " << optAttribute; } +static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, + MutableDictionaryAttr attrs) { + printer.printOptionalAttrDict(attrs.getAttrs()); +} //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1638,6 +1638,14 @@ }]; } +def FormatCustomDirectiveAttrDict + : TEST_Op<"format_custom_directive_attrdict"> { + let arguments = (ins I64Attr:$attr, OptionalAttr:$optAttr); + let assemblyFormat = [{ + custom( attr-dict ) + }]; +} + //===----------------------------------------------------------------------===// // AllTypesMatch type inference diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -845,7 +845,8 @@ body << ", "; if (auto *attr = dyn_cast(&param)) { body << attr->getVar()->name << "Attr"; - + } else if (isa(&param)) { + body << "result.attributes"; } else if (auto *operand = dyn_cast(&param)) { StringRef name = operand->getVar()->name; ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); @@ -1473,12 +1474,18 @@ /// the previous element was a punctuation literal. 
static void genCustomDirectivePrinter(CustomDirective *customDir, OpMethodBody &body) { - body << " print" << customDir->getName() << "(p"; + body << " print" << customDir->getName() << "(p, *this"; for (Element &param : customDir->getArguments()) { body << ", "; if (auto *attr = dyn_cast(&param)) { body << attr->getVar()->name << "Attr()"; + } else if (isa(&param)) { + // Enforce the const-ness since getMutableAttrDict() returns a reference + // into the Operation's `attr` member. + body << "(const " + "MutableDictionaryAttr&)getOperation()->getMutableAttrDict()"; + } else if (auto *operand = dyn_cast(&param)) { body << operand->getVar()->name << "()"; @@ -2735,8 +2742,9 @@ return ::mlir::failure(); // Verify that the element can be placed within a custom directive. - if (!isa(parameters.back().get())) { + if (!isa(parameters.back().get())) { return emitError(childLoc, "only variables and types may be used as " "parameters to a custom directive"); }