diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -72,3 +72,25 @@ def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); } + +// Test attribute definitions +// ODS-LABEL: def Test4Op +// ODS: F32ArrayAttr:$array_attr, +// ODS: F32:$f32_attr, +// ODS: RankedF32ElementsAttr<[4]>:$fvec_attr, +// ODS: I32:$i32_attr, +// ODS: RankedI32ElementsAttr<[5, 6]>:$ivec_attr, +// ODS: OptionalAttr:$optional_attr +// +ods_def : +def test4(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) +attr( + f32_attr: f32, + i32_attr: i32, + fvec_attr: 4xf32, + ivec_attr: 5x6xi32, + array_attr : f32[], + optional_attr? : f32 +) { + C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); +} diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -20,11 +20,17 @@ #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" +#include + #define DEBUG_TYPE "linalg-ods-gen" static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen"); @@ -79,11 +85,14 @@ gt, l_brace, l_paren, + l_square, lt, minus, plus, + question, r_brace, r_paren, + r_square, semicolon, star, @@ -91,6 +100,7 @@ kw_def, FIRST_KEYWORD = kw_def, kw_ods_def, + kw_attr_def, kw_floordiv, kw_ceildiv, kw_mod, @@ -151,6 +161,10 @@ Token emitError(llvm::SMLoc 
loc, const Twine &msg); Token emitError(const char *loc, const Twine &msg); + /// Change the position of the lexer cursor. The next token we lex will start + /// at the designated point in the input. + void resetPointer(const char *newPtr) { curPtr = newPtr; } + private: Token formToken(Token::Kind kind, const char *tokStart) { return Token(kind, StringRef(tokStart, curPtr - tokStart)); @@ -247,10 +261,14 @@ return formToken(Token::Kind::l_brace, tokStart); case '(': return formToken(Token::Kind::l_paren, tokStart); + case '[': + return formToken(Token::Kind::l_square, tokStart); case '}': return formToken(Token::Kind::r_brace, tokStart); case ')': return formToken(Token::Kind::r_paren, tokStart); + case ']': + return formToken(Token::Kind::r_square, tokStart); case '<': return formToken(Token::Kind::lt, tokStart); case '>': @@ -263,6 +281,8 @@ return formToken(Token::Kind::semicolon, tokStart); case '*': return formToken(Token::Kind::star, tokStart); + case '?': + return formToken(Token::Kind::question, tokStart); case '/': if (*curPtr == '/') { skipComment(); @@ -289,6 +309,7 @@ // Check to see if this identifier is a keyword. StringRef str(tokStart, curPtr - tokStart); Token::Kind kind = StringSwitch(str) + .Case("attr", Token::Kind::kw_attr_def) .Case("def", Token::Kind::kw_def) .Case("ods_def", Token::Kind::kw_ods_def) .Case("floordiv", Token::Kind::kw_floordiv) @@ -352,29 +373,40 @@ "shouldn't advance past EOF or errors"); curToken = lexer.lexToken(); } + void consumeToken(Token::Kind kind) { assert(curToken.getKind() == kind && "unexpected token"); curToken = lexer.lexToken(); } + LogicalResult parseToken(Token::Kind kind, const Twine &msg) { if (curToken.getKind() != kind) return emitError(curToken.getLoc(), msg); consumeToken(); return success(); } + + /// Consumes the token if it is of the given kind; returns failure otherwise. 
+ LogicalResult parseOptionalToken(Token::Kind kind) { + return success(consumeIf(kind)); + } + LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { lexer.emitError(loc, msg); return failure(); } + LogicalResult emitError(const Twine &msg) { return emitError(curToken.getLoc(), msg); } + bool consumeIf(Token::Kind kind) { if (curToken.isNot(kind)) return false; consumeToken(kind); return true; } + LogicalResult parseCommaSeparatedList(llvm::function_ref parseElement) { // Non-empty case starts with an element. @@ -388,6 +420,7 @@ } return success(); } + LogicalResult parseCommaSeparatedListUntil(Token::Kind rightToken, llvm::function_ref parseElement, @@ -961,6 +994,8 @@ LogicalResult parseTensorUse(TensorUse &result, ComprehensionParsingState &state); + LogicalResult parseAttrDef(); + /// Parses a tensor expression. LogicalResult parseExpression(TensorUse currentDefinition, std::unique_ptr &result, @@ -1010,15 +1045,29 @@ unsigned index; }; + //===--------------------------------------------------------------------===// + // Internal bookkeeping of attributes. + //===--------------------------------------------------------------------===// + struct RegisteredAttr { + StringRef elementType; + SmallVector vectorDims; + bool isArray; + bool isOptional; + }; + //===--------------------------------------------------------------------===// // Per-TC def state. //===--------------------------------------------------------------------===// /// Symbols are per TC def. AffineSymbolList symbols; + /// Tensors are per TC def. llvm::StringMap registeredTensors; unsigned nextRegisteredTensorIndex; + /// Attributes are per TC def. + std::map registeredAttrs; + Parser &parser; }; } // namespace @@ -1170,6 +1219,72 @@ return success(); } +/// Parse the information for an attribute def of the form: +/// +/// affine-expr-list ::= affine-expr (`,` affine-expr )* +/// attr-id ::= bare-id (`?`)? +/// dim-list ::= (integer-literal 'x')+ +/// attr-typedef ::= dim-list? 
type (`[` `]`)? +/// attr-def ::= attr-id `:` attr-typedef +LogicalResult TCParser::parseAttrDef() { + auto attrLoc = parser.curToken.getLoc(); + StringRef attrName = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) + return failure(); + bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question)); + if (failed(parser.parseToken(Token::Kind::colon, "expected colon"))) + return failure(); + + // Parse the attribute's type. We don't expect the type to be arbitrarily + // complex, so just use this ad-hoc handling here. + + // Parse potential dimension list + SmallVector vectorDims; + while (parser.curToken.is(Token::Kind::integer)) { + vectorDims.push_back(parser.curToken.getUInt64IntegerValue().getValue()); + parser.consumeToken(); + + StringRef spelling = parser.curToken.getSpelling(); + if (spelling[0] != 'x') + return parser.emitError(parser.curToken.getLoc(), + "expected 'x' in dimension list"); + + // If we had a prefix of 'x', lex the next token immediately after the 'x'. + if (spelling.size() != 1) + parser.lexer.resetPointer(spelling.data() + 1); + + parser.consumeToken(); + } + + StringRef elementType = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) + return failure(); + + bool isArray = false; + auto arrayLoc = parser.curToken.getLoc(); + if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) { + isArray = true; + if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'"))) + return failure(); + } + + if (!vectorDims.empty() && isArray) + return parser.emitError(arrayLoc, "unsupported vector array attribute"); + + auto iterBoolPair = registeredAttrs.emplace( + attrName, RegisteredAttr{elementType, vectorDims, isArray, isOptional}); + if (!iterBoolPair.second) + return parser.emitError(attrLoc, + "Failed to register attribute '" + attrName + "'"); + + LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? 
"[optional]" : "") + << " " << attrName << " " + << "with type: " << elementType + << (isArray ? "[]" : "") << "\n"); + + return success(); +} + /// Parses a tensor expression of the form: /// /// op-spec ::= bare-id `<` reduction-dims-list `>` @@ -1341,10 +1456,13 @@ /// Parse and print the information for a ODS def. /// /// tensor-def-list ::= tensor-def (`,` tensor-def )* +/// attr-def-list ::= attr-def (`,` attr-def )* /// /// comprehension-list ::= comprehension comprehension* /// +/// tc-attr-def ::= `attr` `(` attr-def-list `)` /// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` +/// (tc-attr-def)? /// `{` comprehension-list `}` /// /// ods-def ::= `ods_def` `<` bare-id `>` `:` tc-def @@ -1353,6 +1471,7 @@ /// contain only expressions involving symbols and constants), but can /// otherwise contain arbitrary affine expressions. LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) { + // Parse def header (including C++ op name) if (failed(parser.parseToken(Token::Kind::kw_ods_def, "expected 'ods_def' to define a TC ODS")) || failed(parser.parseToken(Token::Kind::lt, "expected '<'"))) @@ -1364,12 +1483,15 @@ failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || failed(parser.parseToken(Token::Kind::colon, "expected ':'"))) return failure(); + if (failed(parser.parseToken(Token::Kind::kw_def, "expected 'def' to define a TC"))) return failure(); StringRef tcName = parser.curToken.getSpelling(); LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n"); + + // Parse input/output tensor definitions if (failed(parser.parseToken(Token::Kind::id, "expected id")) || failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) return failure(); @@ -1392,6 +1514,16 @@ Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false))) return failure(); + // Parse optional attribute definitions + if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) { + if 
(failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) + return failure(); + if (failed(parser.parseCommaSeparatedListUntil( + Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this), + /*allowEmptyList=*/false))) + return failure(); + } + // Since we don't declare symbols separately, we discover them eagerly: each // newly encountered id in a tensor shape expression is treated as a new // symbolic. At this point, all tensors have been parsed and all the symbols @@ -1450,12 +1582,52 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, StringRef linalgOpName, ComprehensionParsingState &state) { + SmallVector attributes; + for (const auto &attr : registeredAttrs) { + llvm::StringRef name = attr.first; + + llvm::StringRef elementType = attr.second.elementType; + std::string odsType = llvm::StringSwitch(elementType) + .Case("f32", "F32") + .Case("i32", "I32") + .Default(""); + if (odsType.empty()) { + parser.emitError("unimplemented support for attribute element type: " + + elementType); + return; + } + + const auto &dims = attr.second.vectorDims; + if (!dims.empty()) { + SmallVector dimStrs; + for (uint64_t dim : dims) + dimStrs.push_back(std::to_string(dim)); + odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType, + llvm::join(dimStrs, ", ")); + } + + assert(dims.empty() || !attr.second.isArray); + if (attr.second.isArray) + odsType = llvm::formatv("{0}ArrayAttr", odsType); + + if (attr.second.isOptional) + odsType = llvm::formatv("OptionalAttr<{0}>", odsType); + + attributes.push_back(llvm::formatv("{0}:${1}", odsType, name)); + } + + std::string attrList = llvm::join(attributes, ",\n"); + if (!attrList.empty()) + attrList = ",\n" + attrList; + const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [ AttrSizedOperandSegments, DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"YieldOp">]> { - let arguments = (ins Variadic:$inputs, - Variadic:$outputs); + let arguments = (ins + Variadic:$inputs, + 
Variadic:$outputs{4} + ); let results = (outs Variadic:$result_tensors); let regions = (region AnyRegion:$region); @@ -1515,7 +1687,7 @@ static std::function getRegionBuilder() {{ return regionBuilder; } // Generic methods. - static unsigned getNumRegionArgs() {{ return {4}; } + static unsigned getNumRegionArgs() {{ return {5}; } std::string getLibraryCallName() {{ return generateLibraryCallName(getOperation()); } @@ -1531,7 +1703,7 @@ } os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs, - state.orderedTensorArgs.size()); + attrList, state.orderedTensorArgs.size()); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`.