diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -451,6 +451,93 @@ This is the main reason there are only a small number of ops today: we expect them to be auto-generated from Tablegen soon. +### Named Payload Ops Specification + +Linalg provides a declarative specification and a generation tool +(`mlir-linalg-ods-gen`) to automatically produce named ops from a notation that +is inspired by Einstein notation. + +The syntax and semantics used in `mlir-linalg-ods-gen` are very much in flight +and borrow from Tensor Comprehensions (TC) but differ in a few dimensions, to +better adapt to Linalg: + +1. The input and output tensor parameters are specified as `id : + type(symbolic-affine-expression-list)` (e.g. `A : f32(M, N + M)`) and each + new symbol is discovered eagerly. TC on the other hand does not allow + general symbolic affine expressions. +1. The output shapes are specified explicitly; in TC they are always derived + from the input shapes. +1. The operations used to specify computations use EDSC intrinsics so that they + can easily be parsed and emitted into a simple region builder without + resorting to more general MLIR parsing. +1. Reduction dimensions are specified with angle bracket notation on the + operation they apply to (e.g. `std_add<k>` specifies that `k` is a reduction + dimension). In TC, a reduction is specified with the `op=` operator and the + reduction dimensions are inferred. +1. The parallel and reduction dimensions are ordered by the textual program + order. For instance, in the comprehension `O(i, j) = std_add<k, l>(...)`, + `i` (resp. `j`) is a parallel iterator encoded by an affine dimension of + position `0` (resp. `1`); `k` (resp. `l`) is a reduction iterator encoded by + an affine dimension of position `2` (resp. `3`). + +These decisions and syntax are subject to evolution and change.
In particular, +op-specific attributes, dynamic ranks, some form of templating, shape +calculation function specification, etc. may be added in the future. + +At this time, the following restrictions are imposed on the syntax and +semantics: + +1. Each def may only contain a single comprehension but each comprehension may + perform multiple updates. +2. Each tensor may only be used with a single indexing expression. + +The following specification may be used to define a named `batchmatmul` op: + +``` +def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { + C(b, m, n) = std_addf<k>(std_mulf(A(b, m, k), B(k, n))); +} +``` + +When `mlir-linalg-ods-gen -gen-ods-decl=1` is called, the following ODS is +produced: + +``` + def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ + NInputs<2>, + NOutputs<1>, + NamedStructuredOpTraits]> { ... } +``` + +When `mlir-linalg-ods-gen -gen-impl=1` is called, the following C++ is produced: + +``` +llvm::Optional<SmallVector<StringRef, 8>> batchmatmul::referenceIterators() { + return SmallVector<StringRef, 8>{ + getParallelIteratorTypeName(), + getParallelIteratorTypeName(), + getParallelIteratorTypeName(), + getReductionIteratorTypeName() }; +} +llvm::Optional<SmallVector<AffineMap, 8>> batchmatmul::referenceIndexingMaps() { + MLIRContext *context = getContext(); + AffineExpr d0, d1, d2, d3; + bindDims(context, d0, d1, d2, d3); + return SmallVector<AffineMap, 8>{ + AffineMap::get(4, 0, {d0, d1, d3}), + AffineMap::get(4, 0, {d3, d2}), + AffineMap::get(4, 0, {d0, d1, d2}) }; +} +void batchmatmul::regionBuilder(ArrayRef<BlockArgument> args) { + using namespace edsc; + using namespace intrinsics; + ValueHandle _0(args[0]), _1(args[1]), _2(args[2]); + ValueHandle _4 = std_mulf(_0, _1); + ValueHandle _5 = std_addf(_2, _4); + (linalg_yield(ValueRange{ _5 })); +} +``` + ## Open Issues and Design Alternatives Multiple open issues and design alternatives are in flight and it is time to lay them out for the community to discuss and pick apart: diff --git
a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -256,7 +256,7 @@ /// OptionalAttr:$strides /// OptionalAttr:$dilations /// OptionalAttr:$padding -/// `strides` denotes the step of each window along the dimension. +/// `strides` denotes the step of each window along the dimension. class PoolingBase_Op props> : LinalgStructured_Op { let description = [{ @@ -821,4 +821,18 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Named Linalg ops, implemented as declarative configurations of generic ops. +//===----------------------------------------------------------------------===// + +def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">; + +class LinalgNamedStructured_Op props> + : Op { + string spec = ?; + let assemblyFormat = "`(` operands `)` attr-dict `:` " + "functional-type(operands, results)"; +} + #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/IR/AffineExpr.h b/mlir/include/mlir/IR/AffineExpr.h --- a/mlir/include/mlir/IR/AffineExpr.h +++ b/mlir/include/mlir/IR/AffineExpr.h @@ -219,7 +219,7 @@ ArrayRef localExprs, MLIRContext *context); -raw_ostream &operator<<(raw_ostream &os, AffineExpr &expr); +raw_ostream &operator<<(raw_ostream &os, AffineExpr expr); template bool AffineExpr::isa() const { if (std::is_same::value) diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -613,7 +613,7 @@ map.getResults().end()); return replaceDimsAndSymbols(dimReplacements, {}); } -raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr &expr) { +raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) { expr.print(os); return os; } diff --git a/mlir/test/CMakeLists.txt
b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -35,6 +35,7 @@ MLIRUnitTests mlir-cpu-runner mlir-edsc-builder-api-test + mlir-linalg-ods-gen mlir-opt mlir-sdbm-api-test mlir-tblgen diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py --- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -21,7 +21,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.td', '.mlir', '.toy', '.ll'] +config.suffixes = ['.td', '.mlir', '.toy', '.ll', '.tc'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -0,0 +1,75 @@ +// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 | FileCheck %s --check-prefix=ODS +// RUN: mlir-linalg-ods-gen %s -gen-impl=1 | FileCheck %s --check-prefix=IMPL + +// RUN: mlir-linalg-ods-gen %s -gen-ods-decl=1 -test-emit-include-td-header \ +// RUN: | mlir-tblgen -gen-op-decls -I %S/../../include + +// ODS-LABEL: def matvecOp : LinalgNamedStructured_Op<"matvec", [ +// ODS-NEXT: NInputs<2>, +// ODS-NEXT: NOutputs<1>, +// ODS-NEXT: NamedStructuredOpTraits]> +// +// IMPL-LABEL: matvec::referenceIterators() { +// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// +// IMPL: matvec::referenceIndexingMaps() { +// IMPL: AffineMap::get(2, 0, {d0, d1}), +// IMPL-NEXT: AffineMap::get(2, 0, {d1}), +// IMPL-NEXT: AffineMap::get(2, 0, {d0}) }; +// +// IMPL: matvec::regionBuilder(ArrayRef args) { +// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); +// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// +def 
matvec(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { + C(m) = std_addf(std_mulf(A(m, k), B(k))); +} + +// ODS-LABEL: def matmulOp : LinalgNamedStructured_Op<"matmul", [ +// ODS-NEXT: NInputs<2>, +// ODS-NEXT: NOutputs<1>, +// ODS-NEXT: NamedStructuredOpTraits]> +// +// IMPL-LABEL: matmul::referenceIterators() { +// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// +// IMPL: matmul::referenceIndexingMaps() { +// IMPL: AffineMap::get(3, 0, {d0, d2}), +// IMPL-NEXT: AffineMap::get(3, 0, {d2, d1}), +// IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}) }; +// +// IMPL: matmul::regionBuilder(ArrayRef args) { +// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); +// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// +def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { + C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); +} + +// ODS-LABEL: def batchmatmulOp : LinalgNamedStructured_Op<"batchmatmul", [ +// ODS-NEXT: NInputs<2>, +// ODS-NEXT: NOutputs<1>, +// ODS-NEXT: NamedStructuredOpTraits]> +// +// IMPL-LABEL: batchmatmul::referenceIterators() { +// IMPL-NEXT: { {{.*}}Parallel{{.*}}, {{.*}}Parallel{{.*}}, {{.*}}Reduction{{.*}} } +// +// IMPL: batchmatmul::referenceIndexingMaps() { +// IMPL: AffineMap::get(4, 0, {d0, d1, d3}), +// IMPL-NEXT: AffineMap::get(4, 0, {d3, d2}), +// IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}) }; +// +// IMPL: batchmatmul::regionBuilder(ArrayRef args) { +// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); +// IMPL: (linalg_yield(ValueRange{ [[e]] })); +// +// TBLGEN: batchmatmulOp +def batchmatmul(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { + C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); +} diff --git 
a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt --- a/mlir/tools/CMakeLists.txt +++ b/mlir/tools/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(mlir-cuda-runner) add_subdirectory(mlir-cpu-runner) +add_subdirectory(mlir-linalg-ods-gen) add_subdirectory(mlir-opt) add_subdirectory(mlir-translate) add_subdirectory(mlir-vulkan-runner) diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt @@ -0,0 +1,10 @@ +add_llvm_tool(mlir-linalg-ods-gen + mlir-linalg-ods-gen.cpp +) +llvm_update_compile_flags(mlir-linalg-ods-gen) +target_link_libraries(mlir-linalg-ods-gen PRIVATE + MLIRParser + MLIRSupport + LLVMCore + LLVMSupport + ) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -0,0 +1,1659 @@ +//===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the implementation for the Tensor Comprehension-inspired +// parser and ODS pretty-printer for specifying Linalg "named ops" from a +// mathematical form. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ToolOutputFile.h" + +#define DEBUG_TYPE "linalg-ods-gen" + +static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen"); + +// Commandline options +static llvm::cl::opt + inputFilename(llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::init("-"), llvm::cl::value_desc("filename")); + +static llvm::cl::opt + outputFilename("o", llvm::cl::desc("Output filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("-")); + +static llvm::cl::opt + genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."), + llvm::cl::cat(ODSGenCat)); + +static llvm::cl::opt + genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"), + llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); + +static llvm::cl::opt testEmitIncludeTdHeader( + "test-emit-include-td-header", + llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end " + "tblgen testing."), + llvm::cl::init(false), llvm::cl::cat(ODSGenCat)); + +using llvm::SetVector; +using llvm::SMLoc; +using llvm::StringRef; +using llvm::Twine; + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Lexer +//===----------------------------------------------------------------------===// + +namespace { +/// This class represents a specific token in the input format. +class Token { +public: + enum class Kind { + // Markers. + eof, + error, + + // Tokens with no info. 
+ colon, + comma, + equal, + gt, + l_brace, + l_paren, + lt, + minus, + plus, + r_brace, + r_paren, + semicolon, + star, + + // Keywords. + kw_def, + FIRST_KEYWORD = kw_def, + kw_floordiv, + kw_ceildiv, + kw_mod, + LAST_KEYWORD = kw_mod, + + // String valued tokens. + id, + integer, + }; + + Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {} + + /// Return the bytes that make up this token. + StringRef getSpelling() const { return spelling; } + + /// Return the kind of this token. + Kind getKind() const { return kind; } + + /// Return a location for this token. + llvm::SMLoc getLoc() const { + return llvm::SMLoc::getFromPointer(spelling.data()); + } + + /// Return if this token is a keyword. + bool isKeyword() const { + return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD; + } + bool is(Kind k) const { return kind == k; } + bool isNot(Kind k) const { return kind != k; } + + Optional getUInt64IntegerValue() const { + bool isHex = spelling.size() > 1 && spelling[1] == 'x'; + + uint64_t result = 0; + if (spelling.getAsInteger(isHex ? 0 : 10, result)) + return None; + return result; + } + +private: + /// Discriminator that indicates the kind of token this is. + Kind kind; + + /// A reference to the entire token contents; this is always a pointer into + /// a memory buffer owned by the source manager. + StringRef spelling; +}; + +/// This class implements a simple lexer. +class Lexer { +public: + Lexer(llvm::SourceMgr &mgr); + + /// Lex the next token and return it. + Token lexToken(); + + /// Emit an error to the lexer with the given location and message. + Token emitError(llvm::SMLoc loc, const Twine &msg); + Token emitError(const char *loc, const Twine &msg); + +private: + Token formToken(Token::Kind kind, const char *tokStart) { + return Token(kind, StringRef(tokStart, curPtr - tokStart)); + } + + /// Return the next character in the stream. + int getNextChar(); + + /// Lex an identifier. 
+ Token lexIdentifier(const char *tokStart); + + // Lex an integer. + Token lexInteger(const char *tokStart); + + // Skip a comment line, starting with a '//'. + void skipComment(); + + llvm::SourceMgr &srcMgr; + StringRef curBuffer; + const char *curPtr; +}; +} // end anonymous namespace + +Lexer::Lexer(llvm::SourceMgr &mgr) : srcMgr(mgr) { + curBuffer = srcMgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer(); + curPtr = curBuffer.begin(); +} + +Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) { + srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg); + return formToken(Token::Kind::error, loc.getPointer()); +} +Token Lexer::emitError(const char *loc, const Twine &msg) { + return emitError(llvm::SMLoc::getFromPointer(loc), msg); +} + +int Lexer::getNextChar() { + char curChar = *curPtr++; + switch (curChar) { + default: + return (unsigned char)curChar; + case 0: { + // A nul character in the stream is either the end of the current buffer + // or a random nul in the file. Disambiguate that here. + if (curPtr - 1 != curBuffer.end()) + return 0; + + // Otherwise, return end of file. + --curPtr; + return EOF; + } + case '\n': + case '\r': + // Handle the newline character by ignoring it and incrementing the line + // count. However, be careful about 'dos style' files with \n\r in them. + // Only treat a \n\r or \r\n as a single line. + if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar) + ++curPtr; + return '\n'; + } +} + +Token Lexer::lexToken() { + while (true) { + const char *tokStart = curPtr; + + // This always consumes at least one character. + int curChar = getNextChar(); + switch (curChar) { + default: + // Handle identifiers: [a-zA-Z_] + if (isalpha(curChar) || curChar == '_') + return lexIdentifier(tokStart); + + // Handle integers: [0-9] + if (isdigit(curChar)) + return lexInteger(tokStart); + + // Unknown character, emit an error. 
+ return emitError(tokStart, "unexpected character"); + + case EOF: + // Return EOF denoting the end of lexing. + return formToken(Token::Kind::eof, tokStart); + + // Lex punctuation. + case ':': + return formToken(Token::Kind::colon, tokStart); + case ',': + return formToken(Token::Kind::comma, tokStart); + case '=': + return formToken(Token::Kind::equal, tokStart); + case '{': + return formToken(Token::Kind::l_brace, tokStart); + case '(': + return formToken(Token::Kind::l_paren, tokStart); + case '}': + return formToken(Token::Kind::r_brace, tokStart); + case ')': + return formToken(Token::Kind::r_paren, tokStart); + case '<': + return formToken(Token::Kind::lt, tokStart); + case '>': + return formToken(Token::Kind::gt, tokStart); + case '+': + return formToken(Token::Kind::plus, tokStart); + case '-': + return formToken(Token::Kind::minus, tokStart); + case ';': + return formToken(Token::Kind::semicolon, tokStart); + case '*': + return formToken(Token::Kind::star, tokStart); + case '/': + if (*curPtr == '/') { + skipComment(); + continue; + } + // Unknown character, emit an error. + return emitError(tokStart, "unexpected character: not a comment"); + + // Ignore whitespace characters. + case 0: + case ' ': + case '\t': + case '\n': + return lexToken(); + } + } +} + +Token Lexer::lexIdentifier(const char *tokStart) { + // Match the rest of the identifier regex: [0-9a-zA-Z_\-]* + while (isalnum(*curPtr) || *curPtr == '_' || *curPtr == '-') + ++curPtr; + + // Check to see if this identifier is a keyword. 
+ StringRef str(tokStart, curPtr - tokStart); + Token::Kind kind = llvm::StringSwitch(str) + .Case("def", Token::Kind::kw_def) + .Case("floordiv", Token::Kind::kw_floordiv) + .Case("ceildiv", Token::Kind::kw_ceildiv) + .Case("mod", Token::Kind::kw_mod) + .Default(Token::Kind::id); + + return Token(kind, str); +} + +Token Lexer::lexInteger(const char *tokStart) { + // Match the rest of the integer regex: [0-9]* + while (isdigit(*curPtr)) + ++curPtr; + + StringRef str(tokStart, curPtr - tokStart); + return Token(Token::Kind::integer, str); +} + +/// Skip a comment line, starting with a '//'. +void Lexer::skipComment() { + // Advance over the second '/' in a '//' comment. + assert(*curPtr == '/'); + ++curPtr; + + while (true) { + switch (*curPtr++) { + case '\n': + case '\r': + // Newline is end of comment. + return; + case 0: + // If this is the end of the buffer, end the comment. + if (curPtr - 1 == curBuffer.end()) { + --curPtr; + return; + } + LLVM_FALLTHROUGH; + default: + // Skip over other characters. + break; + } + } +} + +namespace { + +class Parser { +public: + Parser(llvm::SourceMgr &mgr, MLIRContext *ctx) + : lexer(mgr), curToken(lexer.lexToken()), context(ctx) {} + + //===--------------------------------------------------------------------===// + // Lexer Utilities + //===--------------------------------------------------------------------===// + + /// Advance the current lexer onto the next token.
+ void consumeToken() { + assert(curToken.getKind() != Token::Kind::eof && + curToken.getKind() != Token::Kind::error && + "shouldn't advance past EOF or errors"); + curToken = lexer.lexToken(); + } + void consumeToken(Token::Kind kind) { + assert(curToken.getKind() == kind && "unexpected token"); + curToken = lexer.lexToken(); + } + LogicalResult parseToken(Token::Kind kind, const Twine &msg) { + if (curToken.getKind() != kind) + return emitError(curToken.getLoc(), msg); + consumeToken(); + return success(); + } + LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) { + lexer.emitError(loc, msg); + return failure(); + } + LogicalResult emitError(const Twine &msg) { + return emitError(curToken.getLoc(), msg); + } + bool consumeIf(Token::Kind kind) { + if (curToken.isNot(kind)) + return false; + consumeToken(kind); + return true; + } + LogicalResult + parseCommaSeparatedList(llvm::function_ref parseElement) { + // Non-empty case starts with an element. + if (parseElement()) + return failure(); + + // Otherwise we have a list of comma separated elements. + while (consumeIf(Token::Kind::comma)) { + if (parseElement()) + return failure(); + } + return success(); + } + LogicalResult + parseCommaSeparatedListUntil(Token::Kind rightToken, + llvm::function_ref parseElement, + bool allowEmptyList) { + // Handle the empty case. + if (curToken.is(rightToken)) { + if (!allowEmptyList) + return emitError("expected list element"); + consumeToken(rightToken); + return success(); + } + + if (failed(parseCommaSeparatedList(parseElement)) || + failed( + parseToken(rightToken, "expected ',' or right-terminating token"))) + return failure(); + + return success(); + } + + Lexer lexer; + Token curToken; + MLIRContext *context; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Affine parsing. 
+//===----------------------------------------------------------------------===// + +namespace { + +/// Lower precedence ops (all at the same precedence level). LNoOp is false in +/// the boolean sense. +enum AffineLowPrecOp { + /// Null value. + LNoOp, + Add, + Sub +}; + +/// Higher precedence ops - all at the same precedence level. HNoOp is false +/// in the boolean sense. +enum AffineHighPrecOp { + /// Null value. + HNoOp, + Mul, + FloorDiv, + CeilDiv, + Mod +}; + +using AffineDimList = SmallVector, 4>; +using AffineSymbolList = SmallVector, 4>; + +/// This is a specialized parser for affine expressions. +class AffineParser { +public: + explicit AffineParser(Parser &p, + std::function bareIdParsingHook, + AffineDimList &dimList, AffineSymbolList &symbolList) + : parser(p), bareIdFallback(bareIdParsingHook), dims(dimList), + symbols(symbolList) {} + + /// Parse a comma-separated list of affine exprs. + SmallVector + parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren, + Token::Kind rDelim = Token::Kind::r_paren); + + /// Parse a single affine expr. + AffineExpr parseAffineExpr(); + +private: + // Binary affine op parsing. + AffineLowPrecOp consumeIfLowPrecOp(); + AffineHighPrecOp consumeIfHighPrecOp(); + + // AffineExpr parsing.
+ AffineExpr parseParentheticalExpr(); + AffineExpr parseNegateExpression(AffineExpr lhs); + AffineExpr parseIntegerExpr(); + AffineExpr parseBareIdExpr(); + + AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs, + AffineExpr rhs, SMLoc opLoc); + AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs, + AffineExpr rhs); + AffineExpr parseAffineOperandExpr(AffineExpr lhs); + AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp); + AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc); + + Parser &parser; + std::function bareIdFallback; + AffineDimList &dims; + AffineSymbolList &symbols; +}; +} // end anonymous namespace + +/// Create an affine binary high precedence op expression (mul's, div's, mod). +/// opLoc is the location of the op token to be used to report errors +/// for non-conforming expressions. +AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op, + AffineExpr lhs, AffineExpr rhs, + SMLoc opLoc) { + switch (op) { + case Mul: + if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) { + parser.emitError(opLoc, + "non-affine expression: at least one of the multiply " + "operands has to be either a constant or symbolic"); + return nullptr; + } + return lhs * rhs; + case FloorDiv: + if (!rhs.isSymbolicOrConstant()) { + parser.emitError(opLoc, + "non-affine expression: right operand of floordiv " + "has to be either a constant or symbolic"); + return nullptr; + } + return lhs.floorDiv(rhs); + case CeilDiv: + if (!rhs.isSymbolicOrConstant()) { + parser.emitError(opLoc, "non-affine expression: right operand of ceildiv " + "has to be either a constant or symbolic"); + return nullptr; + } + return lhs.ceilDiv(rhs); + case Mod: + if (!rhs.isSymbolicOrConstant()) { + parser.emitError(opLoc, "non-affine expression: right operand of mod " + "has to be either a constant or symbolic"); + return nullptr; + } + return lhs % rhs; + case HNoOp: + 
llvm_unreachable("can't create affine expression for null high prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineHighPrecOp"); +} + +/// Create an affine binary low precedence op expression (add, sub). +AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op, + AffineExpr lhs, AffineExpr rhs) { + switch (op) { + case AffineLowPrecOp::Add: + return lhs + rhs; + case AffineLowPrecOp::Sub: + return lhs - rhs; + case AffineLowPrecOp::LNoOp: + llvm_unreachable("can't create affine expression for null low prec op"); + return nullptr; + } + llvm_unreachable("Unknown AffineLowPrecOp"); +} + +/// Consume this token if it is a lower precedence affine op (there are only +/// two precedence levels). +AffineLowPrecOp AffineParser::consumeIfLowPrecOp() { + switch (parser.curToken.getKind()) { + case Token::Kind::plus: + parser.consumeToken(); + return AffineLowPrecOp::Add; + case Token::Kind::minus: + parser.consumeToken(); + return AffineLowPrecOp::Sub; + default: + return AffineLowPrecOp::LNoOp; + } +} + +/// Consume this token if it is a higher precedence affine op (there are only +/// two precedence levels) +AffineHighPrecOp AffineParser::consumeIfHighPrecOp() { + switch (parser.curToken.getKind()) { + case Token::Kind::star: + parser.consumeToken(Token::Kind::star); + return Mul; + case Token::Kind::kw_floordiv: + parser.consumeToken(Token::Kind::kw_floordiv); + return FloorDiv; + case Token::Kind::kw_ceildiv: + parser.consumeToken(Token::Kind::kw_ceildiv); + return CeilDiv; + case Token::Kind::kw_mod: + parser.consumeToken(Token::Kind::kw_mod); + return Mod; + default: + return HNoOp; + } +} + +/// Parse a high precedence op expression list: mul, div, and mod are high +/// precedence binary ops, i.e., parse a +/// expr_1 op_1 expr_2 op_2 ... expr_n +/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod). +/// All affine binary ops are left associative. 
+/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is +/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is +/// null. llhsOpLoc is the location of the llhsOp token that will be used to +/// report an error for non-conforming expressions. +AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs, + AffineHighPrecOp llhsOp, + SMLoc llhsOpLoc) { + AffineExpr lhs = parseAffineOperandExpr(llhs); + if (!lhs) + return nullptr; + + // Found an LHS. Parse the remaining expression. + auto opLoc = parser.curToken.getLoc(); + if (AffineHighPrecOp op = consumeIfHighPrecOp()) { + if (llhs) { + AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc); + if (!expr) + return nullptr; + return parseAffineHighPrecOpExpr(expr, op, opLoc); + } + // No LLHS, get RHS + return parseAffineHighPrecOpExpr(lhs, op, opLoc); + } + + // This is the last operand in this expression. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc); + + // No llhs, 'lhs' itself is the expression. + return lhs; +} + +/// Parse an affine expression inside parentheses. +/// +/// affine-expr ::= `(` affine-expr `)` +AffineExpr AffineParser::parseParentheticalExpr() { + if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) + return nullptr; + if (parser.curToken.is(Token::Kind::r_paren)) + return (parser.emitError("no expression inside parentheses"), nullptr); + + auto expr = parseAffineExpr(); + if (!expr) + return nullptr; + if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'"))) + return nullptr; + + return expr; +} + +/// Parse the negation expression. 
+/// +/// affine-expr ::= `-` affine-expr +AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) { + if (failed(parser.parseToken(Token::Kind::minus, "expected '-'"))) + return nullptr; + + AffineExpr operand = parseAffineOperandExpr(lhs); + // Since negation has the highest precedence of all ops (including high + // precedence ops) but lower than parentheses, we are only going to use + // parseAffineOperandExpr instead of parseAffineExpr here. + if (!operand) + // Extra error message although parseAffineOperandExpr would have + // complained. Leads to a better diagnostic. + return (parser.emitError("missing operand of negation"), nullptr); + return (-1) * operand; +} + +/// Parse a bare id that may appear in an affine expression. +/// +/// affine-expr ::= bare-id +AffineExpr AffineParser::parseBareIdExpr() { + if (parser.curToken.isNot(Token::Kind::id)) + return (parser.emitError("expected id"), nullptr); + + StringRef sRef = parser.curToken.getSpelling(); + for (auto &list : {dims, symbols}) { + for (auto entry : list) { + if (entry.first == sRef) { + parser.consumeToken(Token::Kind::id); + return entry.second; + } + } + } + + // Not found, check fallback path. + AffineExpr expr = bareIdFallback(sRef); + if (expr) { + parser.consumeToken(Token::Kind::id); + return expr; + } + + return (parser.emitError("use of undeclared id"), nullptr); +} + +/// Parse a positive integral constant appearing in an affine expression. +/// +/// affine-expr ::= integer-literal +AffineExpr AffineParser::parseIntegerExpr() { + auto val = parser.curToken.getUInt64IntegerValue(); + if (!val.hasValue() || (int64_t)val.getValue() < 0) + return (parser.emitError("constant too large for index"), nullptr); + + parser.consumeToken(Token::Kind::integer); + return getAffineConstantExpr((int64_t)val.getValue(), parser.context); +} + +/// Parses an expression that can be a valid operand of an affine expression. 
+/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary +/// operator, the rhs of which is being parsed. This is used to determine +/// whether an error should be emitted for a missing right operand. +// Eg: for an expression without parentheses (like i + j + k + l), each +// of the four identifiers is an operand. For i + j*k + l, j*k is not an +// operand expression, it's an op expression and will be parsed via +// parseAffineHighPrecOpExpr(). However, for i + (j*k) + -l, (j*k) and +// -l are valid operands that will be parsed by this function. +AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) { + switch (parser.curToken.getKind()) { + case Token::Kind::id: + return parseBareIdExpr(); + case Token::Kind::integer: + return parseIntegerExpr(); + case Token::Kind::l_paren: + return parseParentheticalExpr(); + case Token::Kind::minus: + return parseNegateExpression(lhs); + case Token::Kind::kw_ceildiv: + case Token::Kind::kw_floordiv: + case Token::Kind::kw_mod: + case Token::Kind::plus: + case Token::Kind::star: + if (lhs) + parser.emitError("missing right operand of binary operator"); + else + parser.emitError("missing left operand of binary operator"); + return nullptr; + default: + if (lhs) + parser.emitError("missing right operand of binary operator"); + else + parser.emitError("expected affine expression"); + return nullptr; + } +} + +/// Parse affine expressions that are bare-id's, integer constants, +/// parenthetical affine expressions, and affine op expressions that are a +/// composition of those. +/// +/// All binary op's associate from left to right. +/// +/// {add, sub} have lower precedence than {mul, div, and mod}. +/// +/// Add and sub are themselves at the same precedence level. Mul, floordiv, +/// ceildiv, and mod are at the same higher precedence level. Negation has +/// higher precedence than any binary op. +/// +/// llhs: the affine expression appearing on the left of the one being parsed.
+/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null, +/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned +/// if llhs is non-null; otherwise lhs is returned. This is to deal with left +/// associativity. +/// +/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function +/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where +/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr(). +AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs, + AffineLowPrecOp llhsOp) { + AffineExpr lhs; + if (!(lhs = parseAffineOperandExpr(llhs))) + return nullptr; + + // Found an LHS. Deal with the ops. + if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) { + if (llhs) { + AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs); + return parseAffineLowPrecOpExpr(sum, lOp); + } + // No LLHS, get RHS and form the expression. + return parseAffineLowPrecOpExpr(lhs, lOp); + } + auto opLoc = parser.curToken.getLoc(); + if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) { + // We have a higher precedence op here. Get the rhs operand for the llhs + // through parseAffineHighPrecOpExpr. + AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc); + if (!highRes) + return nullptr; + + // If llhs is null, the product forms the first operand of the yet to be + // found expression. If non-null, the op to associate with llhs is llhsOp. + AffineExpr expr = + llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes; + + // Recurse for subsequent low prec op's after the affine high prec op + // expression. + if (AffineLowPrecOp nextOp = consumeIfLowPrecOp()) + return parseAffineLowPrecOpExpr(expr, nextOp); + return expr; + } + // Last operand in the expression list. + if (llhs) + return getAffineBinaryOpExpr(llhsOp, llhs, lhs); + // No llhs, 'lhs' itself is the expression. + return lhs; +} + +/// Parse an affine expression. 
+/// affine-expr ::= `(` affine-expr `)` +/// | `-` affine-expr +/// | affine-expr `+` affine-expr +/// | affine-expr `-` affine-expr +/// | affine-expr `*` affine-expr +/// | affine-expr `floordiv` affine-expr +/// | affine-expr `ceildiv` affine-expr +/// | affine-expr `mod` affine-expr +/// | bare-id +/// | integer-literal +/// +/// Additional conditions are checked depending on the production. For eg., +/// one of the operands for `*` has to be either constant/symbolic; the second +/// operand for floordiv, ceildiv, and mod has to be a positive integer. +AffineExpr AffineParser::parseAffineExpr() { + return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp); +} + +SmallVector AffineParser::parseAffineExprs(Token::Kind lDelim, + Token::Kind rDelim) { + parser.parseToken(lDelim, "expected lDelim at start of affine expr list"); + + SmallVector exprs; + auto parseElt = [&]() -> LogicalResult { + auto elt = parseAffineExpr(); + exprs.push_back(elt); + return elt ? success() : failure(); + }; + + if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt, + /*allowEmptyList=*/true))) + llvm_unreachable("Failed AffineExpr parsing"); + + return exprs; +} + +//===----------------------------------------------------------------------===// +// TC parsing. +//===----------------------------------------------------------------------===// + +namespace { + +/// Base class for expressions involved in TC parsing. 
+struct Expression { + enum class Kind { + Uninitialized = 0, + TensorExpr = 1, + TensorUse = 2, + }; + + explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {} + virtual ~Expression() = 0; + + bool operator==(const Expression &e) const; + operator bool() const { return kind != Kind::Uninitialized; } + + Kind kind; +}; + +/// Encodes a tensor use of the form: +/// +/// affine-expr-list ::= affine-expr (`,` affine-expr)* +/// tensor-use ::= bare-id `(` `)` +/// | bare-id `(` affine-expr-list `)` +/// +/// The affine-expr-list is stored as an AffineMap. +struct TensorUse : public Expression { + TensorUse() : TensorUse("", AffineMap()) {} + TensorUse(StringRef name, AffineMap map) + : Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {} + TensorUse(const TensorUse &use) = default; + + static bool classof(const Expression *e) { + return e->kind == Kind::TensorUse; + } + + bool operator==(const TensorUse &other) const { + return tensorId == other.tensorId && indexingMap == other.indexingMap; + } + + /// Visitation function. Performs preorder or postorder traversal depending on + /// `PreOrder` and applies `callback` on each node. + template + void visit(Lambda callback) const; + + StringRef tensorId; + AffineMap indexingMap; +}; + +/// Encodes a tensor expression of the form: +/// +/// op-spec ::= bare-id `<` reduction-dims-list `>` +/// | bare-id +/// op-arg ::= tensor-expr +/// | tensor-use +/// op-arg-list ::= op-arg (`,` op-arg)* +/// tensor-expr ::= op-spec `(` op-arg-list `)` +/// +/// Underlying op-arg are stored by unique_ptr to base class. 
+struct TensorExpr : public Expression { + TensorExpr(StringRef name, + SmallVectorImpl> &&exprs, + ArrayRef reductionDims) + : Expression(Kind::TensorExpr), opId(name), expressions(std::move(exprs)), + reductionDimensions(reductionDims.begin(), reductionDims.end()) {} + + static bool classof(const Expression *e) { + return e->kind == Kind::TensorExpr; + } + + bool operator==(const TensorExpr &other) const { + if (opId != other.opId) + return false; + if (expressions.size() != other.expressions.size()) + return false; + for (unsigned i = 0, e = expressions.size(); i < e; ++i) + if (*expressions[i] != *other.expressions[i]) + return false; + for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i) + if (reductionDimensions[i] != other.reductionDimensions[i]) + return false; + return true; + } + + /// Visitation function. Performs preorder or postorder traversal depending on + /// `PreOrder` and applies `callback` on each node. + template + void visit(Lambda callback) const; + + StringRef opId; + SmallVector, 4> expressions; + SetVector reductionDimensions; +}; + +Expression::~Expression() {} + +bool Expression::operator==(const Expression &e) const { + if (this->kind != e.kind) + return false; + if (e.kind == Expression::Kind::TensorUse) + return static_cast(*this) == + static_cast(e); + if (e.kind == Expression::Kind::TensorExpr) + return static_cast(*this) == + static_cast(e); + llvm_unreachable("Unexpected case"); +} + +/// This is a specialized parser for a TCDef. +/// This maintains the dims it finds in an eager fashion. +class TCParser { + enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions }; + +public: + explicit TCParser(Parser &p); + + /// Uses the AffineParser to parse the affine exprs used in a tensor + /// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new + /// symbols (resp. dimensions) are added eagerly. Otherwise, an error is + /// emitted on new identifiers. 
+ SmallVector + parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims, + Token::Kind lDelim = Token::Kind::l_paren, + Token::Kind rDelim = Token::Kind::r_paren); + + /// Parse the information for a tensor def. + /// All the affine-expr must be dimensionless (i.e. contain only expressions + /// involving symbols and constants), but can otherwise contain arbitrary + /// affine expressions. + LogicalResult parseTensorDef(bool isOutput); + + /// Parses a tensor use. + struct ComprehensionParsingState { + AffineDimList dims; + SmallVector, 4> expressions; + llvm::DenseMap orderedTensorArgs; + }; + LogicalResult parseTensorUse(TensorUse &result, + ComprehensionParsingState &state); + + /// Parses a tensor expression. + LogicalResult parseExpression(TensorUse currentDefinition, + std::unique_ptr &result, + ComprehensionParsingState &state); + + /// Parse a single comprehension. + LogicalResult parseOneComprehension(StringRef cppOpName, + StringRef linalgOpName, + ComprehensionParsingState &state); + + /// Parse and print the information for a TC def. + /// When `gen-ods-decl` is used, this prints the ODS declaration for the TC. + /// When `gen-impl` is used, this prints the C++ implementation for the extra + /// methods defined in ODS (referenceIterators, referenceIndexingMaps and + /// regionBuilder). + LogicalResult parseAndEmitTCDef(llvm::raw_ostream &os); + + /// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. + void printODS(llvm::raw_ostream &os, StringRef cppOpName, + StringRef linalgOpName); + + /// Print the C++ StructuredOpsInterface impl of `referenceIterators`. + void printReferenceIterators(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state); + + /// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. + void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state); + + /// Print the C++ StructuredOpsInterface impl of `regionBuilder`. 
+ void printRegionBuilder(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state); + +private: + //===--------------------------------------------------------------------===// + // Internal bookkeeping of tensors. + //===--------------------------------------------------------------------===// + struct RegisteredTensor { + StringRef type; + AffineMap shape; + bool isOutput; + AffineMap indexingMap; + unsigned index; + }; + + //===--------------------------------------------------------------------===// + // Per-TC def state. + //===--------------------------------------------------------------------===// + /// Symbols are per TC def. + AffineSymbolList symbols; + /// Tensors are per TC def. + llvm::StringMap registeredTensors; + unsigned nextRegisteredTensorIndex; + + Parser &parser; +}; +} // namespace + +namespace llvm { + +template <> +struct DenseMapInfo { + static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); } + static TensorUse getTombstoneKey() { + return TensorUse(DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey()); + } + static unsigned getHashValue(const TensorUse &val) { + return ::llvm::hash_value(val.tensorId); // don't care about collisions. + } + static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) { + return LHS == RHS; + } +}; + +} // namespace llvm + +//===----------------------------------------------------------------------===// +// Visitation functions. 
+//===----------------------------------------------------------------------===// + +template +void visit(const Expression &expr, Lambda callback) { + switch (expr.kind) { + default: + llvm_unreachable("Unexpected kind"); + case Expression::Kind::TensorExpr: + static_cast(expr).visit(callback); + break; + case Expression::Kind::TensorUse: + static_cast(expr).visit(callback); + break; + } +} + +template +void visitPreorder(const Expression &expr, Lambda callback) { + visit(expr, callback); +} + +template +void visitPostorder(Expression &expr, Lambda callback) { + visit(expr, callback); +} + +template +void TensorExpr::visit(Lambda callback) const { + if (!PreOrder) + callback(*this); + for (auto &e : expressions) + ::visit(*e, callback); + if (PreOrder) + callback(*this); +} + +template +void TensorUse::visit(Lambda callback) const { + callback(*this); +} + +//===----------------------------------------------------------------------===// +// TC parsing functions. +//===----------------------------------------------------------------------===// +TCParser::TCParser(Parser &p) + : symbols(), registeredTensors(), nextRegisteredTensorIndex(0), parser(p) {} + +/// Uses the AffineParser to parse the affine exprs used in a tensor +/// definition. All identifiers are interpreted as symbols, new symbols are +/// added eagerly. 
+SmallVector +TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode, + AffineDimList &dims, Token::Kind lDelim, + Token::Kind rDelim) { + AffineParser affineParser( + parser, + [&](StringRef sRef) { + AffineExpr expr; + if (discoveryMode == EagerDiscoveryMode::Symbols) { + expr = getAffineSymbolExpr(symbols.size(), parser.context); + symbols.emplace_back(sRef, expr); + } else if (discoveryMode == EagerDiscoveryMode::Dimensions) { + expr = getAffineDimExpr(dims.size(), parser.context); + dims.emplace_back(sRef, expr); + } + return expr; + }, + dims, symbols); + return affineParser.parseAffineExprs(lDelim, rDelim); +} + +/// Parse the information for a tensor def of the form: +/// +/// affine-expr-list ::= affine-expr (`,` affine-expr )* +/// tensor-typedef ::= type `(` `)` +/// | type `(` affine-expr-list `)` +/// tensor-def ::= bare-id `:` tensor-typedef +LogicalResult TCParser::parseTensorDef(bool isOutput) { + StringRef tensorId = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id")) || + failed(parser.parseToken(Token::Kind::colon, "expected colon"))) + return failure(); + + StringRef tensorType = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) + return failure(); + + AffineDimList emptyDims; + auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims); + assert(emptyDims.empty() && "Unexpected dimension in tensor def"); + AffineMap map = + AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context); + + auto iterBoolPair = registeredTensors.try_emplace( + tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(), + nextRegisteredTensorIndex++}); + assert(iterBoolPair.second && "Could not emplace tensor registration"); + LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " " + << "with typeString: " << tensorType << " " + << "and shape: " << map << "\n"); + + return success(); +} + +/// Parses a tensor use of the form: +/// +/// 
affine-expr-list ::= affine-expr (`,` affine-expr)* +/// tensor-use ::= bare-id `(` `)` +/// | bare-id `(` affine-expr-list `)` +LogicalResult TCParser::parseTensorUse(TensorUse &result, + ComprehensionParsingState &state) { + StringRef tensorId = parser.curToken.getSpelling(); + if (failed(parser.parseToken(Token::Kind::id, "expected an id"))) + return failure(); + + auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims); + AffineMap map = + AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context); + LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map + << "\n"); + + result = TensorUse(tensorId, map); + return success(); +} + +/// Parses a tensor expression of the form: +/// +/// op-spec ::= bare-id `<` reduction-dims-list `>` +/// | bare-id +/// op-arg ::= tensor-expr +/// | tensor-use +/// op-arg-list ::= op-arg (`,` op-arg)* +/// tensor-expr ::= op-spec `(` op-arg-list `)` +LogicalResult TCParser::parseExpression(TensorUse currentDefinition, + std::unique_ptr &result, + ComprehensionParsingState &state) { + StringRef opOrTensor = parser.curToken.getSpelling(); + if (registeredTensors.count(opOrTensor) > 0) { + TensorUse use; + auto res = parseTensorUse(use, state); + if (failed(res)) + return res; + result = std::make_unique(use); + return success(); + } + + if (failed(parser.parseToken(Token::Kind::id, "expected an operation"))) + return failure(); + + // This is an op. + SmallVector reductionDims; + SmallVector, 4> expressions; + + // Check if it has a reduction set, discover dimensions eagerly. + if (parser.curToken.is(Token::Kind::lt)) { + auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims, + Token::Kind::lt, Token::Kind::gt); + for (auto iter : iters) + reductionDims.push_back(iter.cast().getPosition()); + } + + // If this op is a reduction, it's first argument is the `currentDefinition` + // tensor use. 
+ if (!reductionDims.empty()) + expressions.push_back(std::make_unique(currentDefinition)); + LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n"); + + auto parseExpr = [&]() -> LogicalResult { + std::unique_ptr e; + if (failed(parseExpression(currentDefinition, e, state))) + return failure(); + expressions.push_back(std::move(e)); + return success(); + }; + if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) || + failed(parser.parseCommaSeparatedListUntil( + Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true))) + return failure(); + + result = std::make_unique(opOrTensor, std::move(expressions), + reductionDims); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Parse and Emit functions. +//===----------------------------------------------------------------------===// + +/// Parse the information for a single comprehension. +/// +/// tensor-def-list ::= tensor-def (`,` tensor-def)* +/// tensor-expr-list ::= tensor-expr (`,` tensor-expr)* +/// comprehension ::= tensor-def-list `=` tensor-expr-list `;` +LogicalResult +TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName, + ComprehensionParsingState &state) { + // 1. Parse LHS of `=`, these become the definitions that appear as the output + // tensors or read/write buffers. + SmallVector definitions; + auto parseUse = [&]() -> LogicalResult { + TensorUse use; + if (failed(parseTensorUse(use, state))) + return failure(); + definitions.push_back(use); + return success(); + }; + if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse, + /*allowEmptyList=*/true))) + return failure(); + + // 2. Parse RHS of `=`, this becomes the expressions from which we emit + // computations. 
+ unsigned idx = 0; + auto parseExpr = [&]() -> LogicalResult { + std::unique_ptr expr; + if (idx >= definitions.size()) { + parser.emitError("Fewer LHS definitions than RHS expressions"); + return failure(); + } + if (failed(parseExpression(definitions[idx++], expr, state))) + return failure(); + state.expressions.push_back(std::move(expr)); + return success(); + }; + if (failed(parser.parseCommaSeparatedListUntil( + Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true))) + return failure(); + if (idx != definitions.size()) { + parser.emitError("Fewer RHS expressions than LHS definitions"); + return failure(); + } + + // 3. Postprocess. + // 3.a. Normalize all maps to the proper state.dims and symbols counts. + SmallVector allUses; + allUses.reserve(registeredTensors.size()); + for (auto &def : definitions) + allUses.push_back(def); + for (auto &pExpr : state.expressions) + visitPostorder(*pExpr, [&](const Expression &e) { + if (auto *use = dyn_cast(&e)) + allUses.push_back(*use); + }); + for (auto &use : allUses) + use.indexingMap = + AffineMap::get(state.dims.size(), symbols.size(), + use.indexingMap.getResults(), parser.context); + + // 3.b. 
Traverse definitions + llvm::DenseSet seenDefs; + for (auto &def : definitions) { + if (seenDefs.count(def.tensorId) > 0) { + parser.emitError("Unexpected multi-write to a single tensor"); + return failure(); + } + seenDefs.insert(def.tensorId); + auto tensorIter = registeredTensors.find(def.tensorId); + assert(tensorIter != registeredTensors.end() && "unregistered tensor"); + auto &tensor = tensorIter->getValue(); + tensor.indexingMap = def.indexingMap; + state.orderedTensorArgs[def] = tensor.index; + } + + bool failed = false; + for (auto &pExpr : state.expressions) + visitPostorder(*pExpr, [&](const Expression &e) { + auto *pUse = dyn_cast(&e); + if (failed || !pUse) + return; + auto &use = *pUse; + LLVM_DEBUG(llvm::dbgs() + << "\nuse: " << use.tensorId << " map: " << use.indexingMap); + auto tensorIter = registeredTensors.find(use.tensorId); + assert(tensorIter != registeredTensors.end() && "unregistered tensor"); + auto &tensor = tensorIter->getValue(); + if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) { + LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap); + parser.emitError( + "Unexpected multi-read of a tensor with different accesses"); + failed = true; + return; + } + seenDefs.insert(use.tensorId); + tensor.indexingMap = use.indexingMap; + state.orderedTensorArgs[use] = tensor.index; + }); + if (failed) + return failure(); + + return success(); +} + +/// Parse and print the information for a TC def. +/// +/// tensor-def-list ::= tensor-def (`,` tensor-def )* +/// +/// comprehension-list ::= comprehension comprehension* +/// +/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)` +/// `{` comprehension-list `}` +/// +/// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e. +/// contain only expressions involving symbols and constants), but can +/// otherwise contain arbitrary affine expressions. 
+LogicalResult TCParser::parseAndEmitTCDef(llvm::raw_ostream &os) { + if (failed(parser.parseToken(Token::Kind::kw_def, + "expected 'def' to define a TC"))) + return failure(); + + StringRef tcName = parser.curToken.getSpelling(); + LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing tc: " << tcName << "\n"); + if (failed(parser.parseToken(Token::Kind::id, "expected id")) || + failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) + return failure(); + + auto parseInputDef = [&]() -> LogicalResult { + return parseTensorDef(/*isOutput=*/false); + }; + if (failed(parser.parseCommaSeparatedListUntil( + Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false))) + return failure(); + + if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) || + failed(parser.parseToken(Token::Kind::gt, "expected '>'")) || + failed(parser.parseToken(Token::Kind::l_paren, "expected '('"))) + return failure(); + auto parseOutputDef = [&]() -> LogicalResult { + return parseTensorDef(/*isOutput=*/true); + }; + if (failed(parser.parseCommaSeparatedListUntil( + Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false))) + return failure(); + + // Since we don't declare symbols separately, we discover them eagerly: each + // newly encountered id in a tensor shape expression is treated as a new + // symbolic. At this point, all tensors have been parsed and all the symbols + // that could be discovered eagerly are now known. Resize all AffineMaps to + // normalize the number of eagerly discovered symbols. 
+ for (auto &tensor : registeredTensors) { + auto &map = tensor.getValue().shape; + map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(), + parser.context); + } + + if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'"))) + return failure(); + + SmallVector perComprehensionStates; + while (parser.curToken.isNot(Token::Kind::r_brace)) { + perComprehensionStates.push_back(ComprehensionParsingState()); + if (failed(parseOneComprehension(tcName, tcName, + perComprehensionStates.back()))) + return failure(); + }; + parser.parseToken(Token::Kind::r_brace, "expected '}'"); + + // Print. + auto nComprehensions = perComprehensionStates.size(); + if (nComprehensions != 1) { + parser.emitError("only 1 comprehension supported for now, got: " + + llvm::Twine(nComprehensions)); + return failure(); + } + if (genODSDecl) { + printODS(os, tcName, tcName); + os << "\n"; + } + if (genODSImpl) { + auto &state = perComprehensionStates.back(); + std::string extraMethods; + llvm::raw_string_ostream ss(extraMethods); + printReferenceIterators(ss, tcName, state); + printReferenceIndexingMaps(ss, tcName, state); + printRegionBuilder(ss, tcName, state); + ss.flush(); + os << extraMethods << "\n"; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// Printing functions +//===----------------------------------------------------------------------===// + +/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`. 
+void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName, + StringRef linalgOpName) { + const char *header = R"FMT( def {0}Op : LinalgNamedStructured_Op<"{1}", [ + NInputs<{2}>, + NOutputs<{3}>, + NamedStructuredOpTraits]> { + let arguments = (ins Variadic:$views); + let results = (outs Variadic:$output_tensors); + let extraClassDeclaration = [{{ + llvm::Optional> referenceIterators(); + llvm::Optional> referenceIndexingMaps(); + void regionBuilder(ArrayRef args); + }]; + let hasFolder = 1; + })FMT"; + + unsigned nInputs = 0, nOutputs = 0; + for (auto &t : registeredTensors) { + if (t.getValue().isOutput) + nOutputs++; + else + nInputs++; + } + + os << llvm::formatv(header, cppOpName, linalgOpName, nInputs, nOutputs); +} + +/// Print the C++ StructuredOpsInterface impl of `referenceIterators`. +void TCParser::printReferenceIterators(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state) { + const char *referenceReferenceIteratorsFmt = + R"FMT( + llvm::Optional> {0}::referenceIterators() { + return SmallVector{{ {1} }; + })FMT"; + + std::string iteratorsStr; + llvm::raw_string_ostream ss(iteratorsStr); + unsigned pos = 0; + interleaveComma(state.dims, ss, [&](std::pair p) { + bool reduction = false; + for (auto &expr : state.expressions) { + visitPostorder(*expr, [&](const Expression &e) { + if (auto *pTensorExpr = dyn_cast(&e)) { + if (pTensorExpr->reductionDimensions.count(pos) > 0) + reduction = true; + } + }); + if (reduction) + break; + } + ss << (reduction ? "getReductionIteratorTypeName()" + : "getParallelIteratorTypeName()"); + pos++; + }); + ss.flush(); + + os << llvm::formatv(referenceReferenceIteratorsFmt, opId, iteratorsStr); +} + +/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`. 
+void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state) { + const char *referenceIndexingMapsFmt = + R"FMT( + llvm::Optional> {0}::referenceIndexingMaps() { + MLIRContext *context = getContext(); + AffineExpr {1}; + bindDims(context, {1}); + return SmallVector{{ {2} }; + })FMT"; + + std::string dimsStr; + llvm::raw_string_ostream ss(dimsStr); + interleaveComma(state.dims, ss, + [&](std::pair p) { ss << p.second; }); + ss.flush(); + + std::string mapsStr; + llvm::raw_string_ostream mapsStringStream(mapsStr); + SmallVector orderedUses(state.orderedTensorArgs.size()); + for (auto it : state.orderedTensorArgs) + orderedUses[it.second] = it.first; + interleaveComma(orderedUses, mapsStringStream, [&](TensorUse u) { + assert(u.indexingMap); + const char *mapFmt = "\n\tAffineMap::get({0}, 0, {1})"; + if (u.indexingMap.isEmpty()) { + mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), "context"); + return; + } + + std::string exprsStr; + llvm::raw_string_ostream exprsStringStream(exprsStr); + exprsStringStream << "{"; + interleaveComma(u.indexingMap.getResults(), exprsStringStream); + exprsStringStream << "}"; + exprsStringStream.flush(); + + mapsStringStream << llvm::formatv(mapFmt, state.dims.size(), exprsStr); + }); + mapsStringStream.flush(); + + os << llvm::formatv(referenceIndexingMapsFmt, opId, dimsStr, mapsStr); +} + +/// Print the C++ StructuredOpsInterface impl of `regionBuilder`. 
+void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef opId, + ComprehensionParsingState &state) { + unsigned count = state.orderedTensorArgs.size(); + llvm::DenseMap subExprsMap; + std::function printExpr; + printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void { + if (auto *pUse = dyn_cast(&e)) { + os << "_" << state.orderedTensorArgs.find(*pUse)->second; + return; + } + auto *pTensorExpr = cast(&e); + if (subExprsMap.count(pTensorExpr) > 0) { + os << "_" << subExprsMap[pTensorExpr]; + } else { + std::string subExprs; + llvm::raw_string_ostream subExprsStringStream(subExprs); + interleaveComma(pTensorExpr->expressions, subExprsStringStream, + [&](const std::unique_ptr &e) { + printExpr(subExprsStringStream, *e); + }); + subExprsStringStream.flush(); + const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});"; + os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->opId, subExprs); + subExprsMap[pTensorExpr] = count; + } + }; + + const char *regionBuilderFmt = R"FMT( + void {0}::regionBuilder(ArrayRef args) { + using namespace edsc; + using namespace intrinsics; + ValueHandle {1}; + {2} + (linalg_yield(ValueRange{ {3} })); + })FMT"; + + unsigned idx = 0; + std::string valueHandleStr; + llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); + interleaveComma(state.orderedTensorArgs, valueHandleStringStream, [&](auto) { + valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; + idx++; + }); + + std::string expressionsStr; + llvm::raw_string_ostream expressionStringStream(expressionsStr); + for (auto &expr : state.expressions) + visitPostorder(*expr, [&](const Expression &e) { + if (e.kind == Expression::Kind::TensorExpr) + printExpr(expressionStringStream, e); + }); + + std::string yieldStr; + llvm::raw_string_ostream yieldStringStream(yieldStr); + interleaveComma(state.expressions, yieldStringStream, + [&](const std::unique_ptr &e) { + printExpr(yieldStringStream, *e); + }); + + 
valueHandleStringStream.flush(); + expressionStringStream.flush(); + yieldStringStream.flush(); + + os << llvm::formatv(regionBuilderFmt, opId, valueHandleStr, expressionsStr, + yieldStr); +} + +/// Iterate over each Tensor Comprehension def. +LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os, + Parser &parser) { + while (parser.curToken.getKind() != Token::Kind::eof) { + TCParser tcParser(parser); + if (failed(tcParser.parseAndEmitTCDef(os))) + return failure(); + } + return success(); +} + +int main(int argc, char **argv) { + llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen"); + + // Set up the input file. + std::string errorMessage; + std::unique_ptr file = + mlir::openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + std::unique_ptr output = + openOutputFile(outputFilename, &errorMessage); + if (!output) { + llvm::errs() << errorMessage << "\n"; + exit(1); + } + + // Include the proper Linalg header for end-to-end tblgen testing without + // resorting to non-portable shgell manipulations. + if (testEmitIncludeTdHeader) + output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\""; + + MLIRContext context; + llvm::SourceMgr mgr; + mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); + Parser parser(mgr, &context); + parseAndEmitAllTensorComprehensions(output->os(), parser); + output->keep(); + + return 0; +}