diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -37,6 +37,8 @@ /// Print implementations for various things an operation contains. virtual void printOperand(Value value) = 0; + /// Print the ID of the given value to 'os' instead of the printer's stream. + virtual void printOperand(Value value, raw_ostream &os) = 0; /// Print a comma separated list of operands. template <typename ContainerType> @@ -245,6 +247,20 @@ return success(); } + /// Return the name of the specified result in the specified syntax, as well + /// as the subelement in the name. It returns an empty string and ~0U for + /// invalid result numbers. For example, in this operation: + /// + /// %x, %y:2, %z = foo.op + /// + /// getResultName(0) == {"x", 0 } + /// getResultName(1) == {"y", 0 } + /// getResultName(2) == {"y", 1 } + /// getResultName(3) == {"z", 0 } + /// getResultName(4) == {"", ~0U } + virtual std::pair<StringRef, unsigned> + getResultName(unsigned resultNo) const = 0; + /// Return the location of the original name token. virtual llvm::SMLoc getNameLoc() const = 0; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1963,7 +1963,8 @@ bool printBlockTerminator = true); /// Print the ID of the given value, optionally with its result number. - void printValueID(Value value, bool printResultNo = true) const; + void printValueID(Value value, bool printResultNo = true, + raw_ostream *streamOverride = nullptr) const; //===--------------------------------------------------------------------===// // OpAsmPrinter methods @@ -1988,6 +1989,9 @@ /// Print the ID for the given value. void printOperand(Value value) override { printValueID(value); } + void printOperand(Value value, raw_ostream &os) override { + printValueID(value, /*printResultNo=*/true, &os); + } /// Print an optional attribute dictionary with a given set of elided values.
void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, @@ -2195,8 +2199,11 @@ currentIndent -= indentWidth; } -void OperationPrinter::printValueID(Value value, bool printResultNo) const { - state->getSSANameState().printValueID(value, printResultNo, os); +void OperationPrinter::printValueID(Value value, bool printResultNo, + raw_ostream *streamOverride) const { + if (!streamOverride) + streamOverride = &os; + state->getSSANameState().printValueID(value, printResultNo, *streamOverride); } void OperationPrinter::printSuccessor(Block *successor) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3322,8 +3322,13 @@ Operation *parseGenericOperation(Block *insertBlock, Block::iterator insertPt); + /// This is the structure of a result specifier in the assembly syntax, + /// including the '%'-prefixed name, number of results, and location. + using ResultRecord = std::tuple<StringRef, unsigned, SMLoc>; + /// Parse an operation instance that is in the op-defined custom form. - Operation *parseCustomOperation(); + /// 'resultIDs' describes the "%name =" result specifiers, if any. + Operation *parseCustomOperation(ArrayRef<ResultRecord> resultIDs); //===--------------------------------------------------------------------===// // Region Parsing @@ -3728,7 +3733,7 @@ /// ParseResult OperationParser::parseOperation() { auto loc = getToken().getLoc(); - SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs; + SmallVector<ResultRecord, 1> resultIDs; size_t numExpectedResults = 0; if (getToken().is(Token::percent_identifier)) { // Parse the group of result ids. @@ -3769,7 +3774,7 @@ Operation *op; if (getToken().is(Token::bare_identifier) || getToken().isKeyword()) - op = parseCustomOperation(); + op = parseCustomOperation(resultIDs); else if (getToken().is(Token::string)) op = parseGenericOperation(); else @@ -3790,7 +3795,7 @@ // Add definitions for each of the result groups.
unsigned opResI = 0; - for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) { + for (ResultRecord &resIt : resultIDs) { for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) { if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, op->getResult(opResI++))) @@ -3955,9 +3960,12 @@ namespace { class CustomOpAsmParser : public OpAsmParser { public: - CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition, + CustomOpAsmParser(SMLoc nameLoc, + ArrayRef<OperationParser::ResultRecord> resultIDs, + const AbstractOperation *opDefinition, OperationParser &parser) - : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {} + : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition), + parser(parser) {} /// Parse an instance of the operation described by 'opDefinition' into the /// provided operation state. @@ -3992,6 +4000,34 @@ Builder &getBuilder() const override { return parser.builder; } + /// Return the name of the specified result in the specified syntax, as well + /// as the subelement in the name. For example, in this operation: + /// + /// %x, %y:2, %z = foo.op + /// + /// getResultName(0) == {"x", 0 } + /// getResultName(1) == {"y", 0 } + /// getResultName(2) == {"y", 1 } + /// getResultName(3) == {"z", 0 } + /// getResultName(4) == {"", ~0U } + std::pair<StringRef, unsigned> + getResultName(unsigned resultNo) const override { + // Scan for the resultID that contains this result number. + for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) { + const auto &entry = resultIDs[nameID]; + if (resultNo < std::get<1>(entry)) { + // Don't pass on the leading %. + StringRef name = std::get<0>(entry).drop_front(); + return {name, resultNo}; + } + // Not in this group: make resultNo relative to the next group. + resultNo -= std::get<1>(entry); + } + + // Invalid result number. + return {"", ~0U}; + } + llvm::SMLoc getNameLoc() const override { return nameLoc; } //===--------------------------------------------------------------------===// @@ -4500,6 +4536,9 @@ /// The source location of the operation name. SMLoc nameLoc; + /// Information about the result name specifiers.
+ ArrayRef<OperationParser::ResultRecord> resultIDs; + /// The abstract information of the operation. const AbstractOperation *opDefinition; @@ -4511,7 +4547,8 @@ }; } // end anonymous namespace. -Operation *OperationParser::parseCustomOperation() { +Operation * +OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) { auto opLoc = getToken().getLoc(); auto opName = getTokenSpelling(); @@ -4544,7 +4581,7 @@ // Have the op implementation take a crack and parsing this. OperationState opState(srcLocation, opDefinition->name); CleanupOpStateRegions guard{opState}; - CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this); + CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this); if (opAsmParser.parseOperation(opState)) return nullptr; diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -1185,3 +1185,32 @@ // CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]] return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32 } + + // CHECK-LABEL: func @pretty_names + + // This tests round-tripping an op's 'name' attribute through its SSA name. +func @pretty_names() { + // Simple case, should parse and print as %x being an implied 'name' + // attribute. + %x = test.string_attr_pretty_name + // CHECK: %x = test.string_attr_pretty_name + // CHECK-NOT: attributes + + // This specifies an explicit name, which should override the SSA result name. + %YY = test.string_attr_pretty_name attributes { name = "y" } + // CHECK: %y = test.string_attr_pretty_name + // CHECK-NOT: attributes + + // Conflicts with the 'y' name, so needs an explicit attribute. + %0 = "test.string_attr_pretty_name"() { name = "y"} : () -> i32 + // CHECK: %y_0 = test.string_attr_pretty_name attributes {name = "y"} + + // Name contains a space.
+ %1 = "test.string_attr_pretty_name"() { name = "space name"} : () -> i32 + // CHECK: %space_name = test.string_attr_pretty_name attributes {name = "space name"} + + "unknown.use"(%x, %YY, %0, %1) : (i32, i32, i32, i32) -> () + return +} + diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -391,6 +391,72 @@ } } +//===----------------------------------------------------------------------===// +// StringAttrPrettyNameOp +//===----------------------------------------------------------------------===// + +// This op round-trips its 'name' string attribute through its SSA result name. + +static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser, + OperationState &result) { + result.addTypes(parser.getBuilder().getIntegerType(32)); + + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // Check whether the attribute dictionary already provides an explicit + // 'name'; if it does not, one may be inferred from the SSA name below. + bool hadName = false; + for (auto &attr : result.attributes) { + if (attr.first.is("name")) { + hadName = true; + break; + } + } + + // Without an explicit 'name', fall back to the SSA result name written in + // the asm, skipping a missing result name and numeric IDs such as %0. + if (!hadName) { + auto resultName = parser.getResultName(0); + if (resultName.second == 0 && !resultName.first.empty() && + !isdigit(resultName.first[0])) { + auto *context = result.getContext(); + auto nameAttr = StringAttr::get(resultName.first, context); + result.attributes.push_back({Identifier::get("name", context), nameAttr}); + } + } + + return success(); +} + +static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) { + p << "test.string_attr_pretty_name"; + + // Note that we only need to print the "name" attribute if the AsmPrinter's + // result name disagrees with it. This can happen in strange cases, e.g. + // when there are conflicts.
+ SmallString<32> resultNameStr; + llvm::raw_svector_ostream tmpStream(resultNameStr); + p.printOperand(op.getResult(), tmpStream); + + // If the name is the same as we would otherwise use, then we're good! Pass + // the braced list straight to the call: assigning it to a local ArrayRef + // first would leave that ArrayRef pointing at a destroyed temporary. + if (tmpStream.str().drop_front() == op.name()) + p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"name"}); + else + p.printOptionalAttrDictWithKeyword(op.getAttrs()); +} + +// We set the SSA name in the asm syntax to the contents of the name attribute. +void StringAttrPrettyNameOp::getAsmResultNames( + function_ref<void(Value, StringRef)> setNameFn) { + // An empty 'name' leaves the printer's default SSA name in place. + auto value = name(); + if (!value.empty()) + setNameFn(getResult(), value); +} + //===----------------------------------------------------------------------===// // Dialect Registration //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -493,6 +493,18 @@ ); } +// This is used to test round-tripping a string attribute through the SSA name +// of the op's pretty-printed result. +def StringAttrPrettyNameOp + : TEST_Op<"string_attr_pretty_name", + [DeclareOpInterfaceMethods<OpAsmOpInterface>]> { + let arguments = (ins StrAttr:$name); + let results = (outs I32:$r); + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + //===----------------------------------------------------------------------===// // Test Patterns //===----------------------------------------------------------------------===//