diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -608,10 +608,18 @@ perform multiple updates. 2. Each tensor may only be used with a single indexing expression. +A `"""`-wrapped doc string can be attached to the named op. It should contain +a one-line summary first, followed by a lengthy description. + The following specification may be used to define a named `batchmatmul` op: ``` -def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { +def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) +"""Batch matrix-multiply operation. + +This operation performs batch matrix-multiply over ... +""" +{ C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } ``` diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -125,3 +125,22 @@ O(n, h, w, f) = std_addf(std_mulf( I(n, h * strides[0] + kh, w * strides[1] + kw, c), K(f, kh, kw, c))); } + +// ODS-LABEL: def Test6Op +// ODS: let summary = [{ My magic op. }]; +// ODS-NEXT: let description = [{ +// ODS-NEXT: It has two inputs. +// ODS-NEXT: It has one output. +// ODS-NEXT: }]; +// +ods_def: +def test6(A: f32(M, K), B: f32(K)) -> (C: f32(M)) +""" +My magic op. + +It has two inputs. +It has one output. 
+""" +{ + C(m) = std_addf(std_mulf(A(m, k), B(k))); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" @@ -85,6 +86,7 @@ // Tokens with no info. colon, comma, + doc_str, equal, gt, l_brace, @@ -183,6 +185,9 @@ // Lex an integer. Token lexInteger(const char *tokStart); + // Lex a doc string delimited by triple double quotes; anything else is an error. + Token lexString(const char *tokStart); + // Skip a comment line, starting with a '//'. void skipComment(); @@ -287,6 +292,8 @@ return formToken(Token::Kind::star, tokStart); case '?': return formToken(Token::Kind::question, tokStart); + case '"': + return lexString(tokStart); case '/': if (*curPtr == '/') { skipComment(); @@ -333,6 +340,36 @@ return Token(Token::Kind::integer, str); } +Token Lexer::lexString(const char *tokStart) { + assert(curPtr[-1] == '"'); + + if (*curPtr == '"' && *(curPtr + 1) == '"') { + curPtr += 2; + while (true) { + switch (*curPtr++) { + case '"': + if (*curPtr == '"' && *(curPtr + 1) == '"') { + Token token(Token::Kind::doc_str, + StringRef(tokStart + 3, curPtr - tokStart - 4)); + curPtr += 2; + return token; + } + continue; + case 0: + // If this is a random nul character in the middle of the doc string, + // just include it. If it is the end of file, then it is an error. + if (curPtr - 1 != curBuffer.end()) + continue; + return emitError(curPtr - 1, "expected '\"\"\"' to end doc string"); + default: + continue; + } + } + } + + return emitError(curPtr - 1, "expected '\"\"\"' to start doc string"); +} + /// Skip a comment line, starting with a '//'. void Lexer::skipComment() { // Advance over the second '/' in a '//' comment. 
@@ -1133,6 +1170,8 @@ /// Attributes are per TC def. std::map registeredAttrs; + StringRef docString; + Parser &parser; }; } // namespace @@ -1654,6 +1693,14 @@ return failure(); } + // Parse optional doc string + if (parser.curToken.is(Token::Kind::doc_str)) { + docString = parser.curToken.getSpelling(); + parser.consumeToken(); + LLVM_DEBUG(llvm::dbgs() + << "parsed doc string: '''" << docString << "'''\n"); + } + // Since we don't declare symbols separately, we discover them eagerly: each // newly encountered id in a tensor shape expression is treated as a new // symbolic. At this point, all tensors have been parsed and all the symbols @@ -1754,9 +1801,10 @@ AttrSizedOperandSegments, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"YieldOp">]> { + {2} let arguments = (ins Variadic:$inputs, - Variadic:$outputs{4} + Variadic:$outputs{3} ); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); @@ -1817,23 +1865,30 @@ static std::function getRegionBuilder() {{ return regionBuilder; } // Generic methods. 
- static unsigned getNumRegionArgs() {{ return {5}; } + static unsigned getNumRegionArgs() {{ return {4}; } std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } }]; })FMT"; - unsigned nInputs = 0, nOutputs = 0; - for (auto &t : registeredTensors) { - if (t.getValue().isOutput) - nOutputs++; - else - nInputs++; + std::string doc; + + if (!docString.empty()) { + const char *docFmt = R"FMT( + let summary = [{ {0} }]; + let description = [{ + {1} + }]; + )FMT"; + + StringRef summary, description; + std::tie(summary, description) = docString.trim().split('\n'); + doc = llvm::formatv(docFmt, summary.trim(), description.trim()); } - os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs, - attrList, state.orderedTensorArgs.size()); + os << llvm::formatv(header, cppOpName, linalgOpName, doc, attrList, + state.orderedTensorArgs.size()); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`.