diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h
--- a/mlir/include/mlir/IR/OpImplementation.h
+++ b/mlir/include/mlir/IR/OpImplementation.h
@@ -37,6 +37,8 @@
 
   /// Print implementations for various things an operation contains.
   virtual void printOperand(Value value) = 0;
+  /// Print `value` to `os` instead of to the printer's own stream.
+  virtual void printOperand(Value value, raw_ostream &os) = 0;
 
   /// Print a comma separated list of operands.
   template <typename ContainerType>
@@ -245,6 +247,24 @@
     return success();
   }
 
+  /// Return the name of the specified result in the specified syntax, as well
+  /// as the sub-element in the name.  It returns an empty string and ~0U for
+  /// invalid result numbers.  For example, in this operation:
+  ///
+  ///  %x, %y:2, %z = foo.op
+  ///
+  ///    getResultName(0) == {"x", 0 }
+  ///    getResultName(1) == {"y", 0 }
+  ///    getResultName(2) == {"y", 1 }
+  ///    getResultName(3) == {"z", 0 }
+  ///    getResultName(4) == {"", ~0U }
+  virtual std::pair<StringRef, unsigned>
+  getResultName(unsigned resultNo) const = 0;
+
+  /// Return the number of declared SSA results.  This returns 4 for the foo.op
+  /// example in the comment for `getResultName`.
+  virtual size_t getNumResults() const = 0;
+
   /// Return the location of the original name token.
   virtual llvm::SMLoc getNameLoc() const = 0;
 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1963,7 +1963,8 @@
              bool printBlockTerminator = true);
 
   /// Print the ID of the given value, optionally with its result number.
-  void printValueID(Value value, bool printResultNo = true) const;
+  void printValueID(Value value, bool printResultNo = true,
+                    raw_ostream *streamOverride = nullptr) const;
 
   //===--------------------------------------------------------------------===//
   // OpAsmPrinter methods
@@ -1988,6 +1989,9 @@
 
   /// Print the ID for the given value.
   void printOperand(Value value) override { printValueID(value); }
+  void printOperand(Value value, raw_ostream &os) override {
+    printValueID(value, /*printResultNo=*/true, &os);
+  }
 
   /// Print an optional attribute dictionary with a given set of elided values.
   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
@@ -2195,8 +2199,10 @@
   currentIndent -= indentWidth;
 }
 
-void OperationPrinter::printValueID(Value value, bool printResultNo) const {
-  state->getSSANameState().printValueID(value, printResultNo, os);
+void OperationPrinter::printValueID(Value value, bool printResultNo,
+                                    raw_ostream *streamOverride) const {
+  state->getSSANameState().printValueID(value, printResultNo,
+                                        streamOverride ? *streamOverride : os);
 }
 
 void OperationPrinter::printSuccessor(Block *successor) {
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -3322,8 +3322,13 @@
   Operation *parseGenericOperation(Block *insertBlock,
                                    Block::iterator insertPt);
 
+  /// This is the structure of a result specifier in the assembly syntax,
+  /// including the name, number of results, and location.
+  typedef std::tuple<StringRef, unsigned, SMLoc> ResultRecord;
+
   /// Parse an operation instance that is in the op-defined custom form.
-  Operation *parseCustomOperation();
+  /// resultInfo specifies information about the "%name =" specifiers.
+  Operation *parseCustomOperation(ArrayRef<ResultRecord> resultInfo);
 
   //===--------------------------------------------------------------------===//
   // Region Parsing
@@ -3728,7 +3733,7 @@
 ///
 ParseResult OperationParser::parseOperation() {
   auto loc = getToken().getLoc();
-  SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
+  SmallVector<ResultRecord, 1> resultIDs;
   size_t numExpectedResults = 0;
   if (getToken().is(Token::percent_identifier)) {
     // Parse the group of result ids.
@@ -3769,7 +3774,7 @@
 
   Operation *op;
   if (getToken().is(Token::bare_identifier) || getToken().isKeyword())
-    op = parseCustomOperation();
+    op = parseCustomOperation(resultIDs);
   else if (getToken().is(Token::string))
     op = parseGenericOperation();
   else
@@ -3790,7 +3795,7 @@
 
     // Add definitions for each of the result groups.
     unsigned opResI = 0;
-    for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
+    for (ResultRecord &resIt : resultIDs) {
       for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
         if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
                           op->getResult(opResI++)))
@@ -3955,9 +3960,12 @@
 namespace {
 class CustomOpAsmParser : public OpAsmParser {
 public:
-  CustomOpAsmParser(SMLoc nameLoc, const AbstractOperation *opDefinition,
+  CustomOpAsmParser(SMLoc nameLoc,
+                    ArrayRef<OperationParser::ResultRecord> resultIDs,
+                    const AbstractOperation *opDefinition,
                     OperationParser &parser)
-      : nameLoc(nameLoc), opDefinition(opDefinition), parser(parser) {}
+      : nameLoc(nameLoc), resultIDs(resultIDs), opDefinition(opDefinition),
+        parser(parser) {}
 
   /// Parse an instance of the operation described by 'opDefinition' into the
   /// provided operation state.
@@ -3992,6 +4000,43 @@
 
   Builder &getBuilder() const override { return parser.builder; }
 
+  /// Return the name of the specified result in the specified syntax, as well
+  /// as the sub-element in the name.  It returns an empty string and ~0U for
+  /// invalid result numbers.  For example, in this operation:
+  ///
+  ///  %x, %y:2, %z = foo.op
+  ///
+  ///    getResultName(0) == {"x", 0 }
+  ///    getResultName(1) == {"y", 0 }
+  ///    getResultName(2) == {"y", 1 }
+  ///    getResultName(3) == {"z", 0 }
+  ///    getResultName(4) == {"", ~0U }
+  std::pair<StringRef, unsigned>
+  getResultName(unsigned resultNo) const override {
+    // Scan for the resultID that contains this result number.
+    for (unsigned nameID = 0, e = resultIDs.size(); nameID != e; ++nameID) {
+      const auto &entry = resultIDs[nameID];
+      if (resultNo < std::get<1>(entry)) {
+        // Don't pass on the leading %.
+        StringRef name = std::get<0>(entry).drop_front();
+        return {name, resultNo};
+      }
+      resultNo -= std::get<1>(entry);
+    }
+
+    // Invalid result number.
+    return {"", ~0U};
+  }
+
+  /// Return the number of declared SSA results.  This returns 4 for the foo.op
+  /// example in the comment for getResultName.
+  size_t getNumResults() const override {
+    size_t count = 0;
+    for (auto &entry : resultIDs)
+      count += std::get<1>(entry);
+    return count;
+  }
+
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
   //===--------------------------------------------------------------------===//
@@ -4500,6 +4545,9 @@
   /// The source location of the operation name.
   SMLoc nameLoc;
 
+  /// Information about the result name specifiers.
+  ArrayRef<OperationParser::ResultRecord> resultIDs;
+
   /// The abstract information of the operation.
   const AbstractOperation *opDefinition;
 
@@ -4511,7 +4559,8 @@
 };
 } // end anonymous namespace.
 
-Operation *OperationParser::parseCustomOperation() {
+Operation *
+OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
   auto opLoc = getToken().getLoc();
   auto opName = getTokenSpelling();
 
@@ -4544,7 +4593,7 @@
   // Have the op implementation take a crack and parsing this.
   OperationState opState(srcLocation, opDefinition->name);
   CleanupOpStateRegions guard{opState};
-  CustomOpAsmParser opAsmParser(opLoc, opDefinition, *this);
+  CustomOpAsmParser opAsmParser(opLoc, resultIDs, opDefinition, *this);
   if (opAsmParser.parseOperation(opState))
     return nullptr;
 
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -1185,3 +1185,47 @@
   // CHECK: return %[[FIRST]], %[[MIDDLE]]#0, %[[MIDDLE]]#1, %[[LAST]], %[[FIRST_2]], %[[LAST_2]]
   return %0, %1#0, %1#1, %2, %3, %4, %5 : i32, i32, i32, i32, i32, i32, i32
 }
+
+
+// CHECK-LABEL: func @pretty_names
+
+// Test that SSA result names round-trip through the `names` attribute.
+func @pretty_names() {
+  // Simple case, should parse and print as %x being an implied 'names'
+  // attribute.
+  %x = test.string_attr_pretty_name
+  // CHECK: %x = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+
+  // This specifies an explicit name, which should override the result name.
+  %YY = test.string_attr_pretty_name attributes { names = ["y"] }
+  // CHECK: %y = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+
+  // Conflicts with the 'y' name, so it needs an explicit attribute.
+  %0 = "test.string_attr_pretty_name"() { names = ["y"]} : () -> i32
+  // CHECK: %y_0 = test.string_attr_pretty_name attributes {names = ["y"]}
+
+  // Name contains a space.
+  %1 = "test.string_attr_pretty_name"() { names = ["space name"]} : () -> i32
+  // CHECK: %space_name = test.string_attr_pretty_name attributes {names = ["space name"]}
+
+  // More names than results: only the first one is used for the SSA name.
+  %2 = "test.string_attr_pretty_name"() { names = ["m", "n"]} : () -> i32
+  // CHECK: %m = test.string_attr_pretty_name attributes {names = ["m", "n"]}
+
+  "unknown.use"(%x, %YY, %0, %1) : (i32, i32, i32, i32) -> ()
+
+  // Multi-result support.
+
+  %a, %b, %c = test.string_attr_pretty_name
+  // CHECK: %a, %b, %c = test.string_attr_pretty_name
+  // CHECK-NOT: attributes
+
+  %q:3, %r = test.string_attr_pretty_name
+  // CHECK: %q, %q_1, %q_2, %r = test.string_attr_pretty_name attributes {names = ["q", "q", "q", "r"]}
+
+  // CHECK: return
+  return
+}
+
diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp
--- a/mlir/test/lib/TestDialect/TestDialect.cpp
+++ b/mlir/test/lib/TestDialect/TestDialect.cpp
@@ -392,6 +392,89 @@
 }
 
 //===----------------------------------------------------------------------===//
+// StringAttrPrettyNameOp
+//===----------------------------------------------------------------------===//
+
+// This op has fancy handling of its SSA result name.
+static ParseResult parseStringAttrPrettyNameOp(OpAsmParser &parser,
+                                               OperationState &result) {
+  // Add the result types.
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
+    result.addTypes(parser.getBuilder().getIntegerType(32));
+
+  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
+    return failure();
+
+  // If the attribute dictionary contains no 'names' attribute, infer it from
+  // the SSA name (if specified).
+  bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
+    return attr.first.is("names");
+  });
+
+  // Nothing to infer if `names` was given explicitly or there are no results.
+  if (hadNames || parser.getNumResults() == 0)
+    return success();
+
+  SmallVector<StringRef, 4> names;
+  auto *context = result.getContext();
+
+  for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
+    auto resultName = parser.getResultName(i);
+    StringRef nameStr;
+    // Numeric result ids such as %0 carry no chosen name; leave them empty.
+    if (!resultName.first.empty() && !isdigit(resultName.first[0]))
+      nameStr = resultName.first;
+
+    names.push_back(nameStr);
+  }
+
+  auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
+  result.attributes.push_back({Identifier::get("names", context), namesAttr});
+  return success();
+}
+
+static void print(OpAsmPrinter &p, StringAttrPrettyNameOp op) {
+  p << "test.string_attr_pretty_name";
+
+  // Note that we only need to print the "names" attribute if the asmprinter
+  // result name disagrees with it.  This can happen in strange cases, e.g.
+  // when there are conflicts.
+  bool namesDisagree = op.names().size() != op.getNumResults();
+
+  SmallString<32> resultNameStr;
+  for (size_t i = 0, e = op.getNumResults(); i != e && !namesDisagree; ++i) {
+    resultNameStr.clear();
+    llvm::raw_svector_ostream tmpStream(resultNameStr);
+    p.printOperand(op.getResult(i), tmpStream);
+
+    auto expectedName = op.names()[i].dyn_cast<StringAttr>();
+    if (!expectedName ||
+        tmpStream.str().drop_front() != expectedName.getValue()) {
+      namesDisagree = true;
+    }
+  }
+
+  if (namesDisagree)
+    p.printOptionalAttrDictWithKeyword(op.getAttrs());
+  else
+    p.printOptionalAttrDictWithKeyword(op.getAttrs(), {"names"});
+}
+
+// We set the SSA name in the asm syntax to the contents of the name
+// attribute.  Only entries that have a matching result are used: the generic
+// form can carry a `names` array whose length differs from the number of
+// results, and getResult() must not be called past the last result.
+void StringAttrPrettyNameOp::getAsmResultNames(
+    function_ref<void(Value, StringRef)> setNameFn) {
+  auto value = names();
+  size_t numNamed = std::min<size_t>(value.size(), getNumResults());
+  for (size_t i = 0; i != numNamed; ++i)
+    if (auto str = value[i].dyn_cast<StringAttr>())
+      if (!str.getValue().empty())
+        setNameFn(getResult(i), str.getValue());
+}
+
+//===----------------------------------------------------------------------===//
 // Dialect Registration
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -496,6 +496,18 @@
   );
 }
 
+// This is used to test encoding of a string attribute into the SSA name of a
+// pretty-printed value.
+def StringAttrPrettyNameOp
+ : TEST_Op<"string_attr_pretty_name",
+           [DeclareOpInterfaceMethods<OpAsmOpInterface>]> {
+  let arguments = (ins StrArrayAttr:$names);
+  let results = (outs Variadic<I32>:$r);
+
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//