diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -75,6 +75,7 @@
   mlir-capi-quant-test
   mlir-capi-sparse-tensor-test
   mlir-capi-transform-test
+  mlir-bc-tblgen
   mlir-linalg-ods-yaml-gen
   mlir-lsp-server
   mlir-pdll-lsp-server
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -8,6 +8,7 @@
 add_subdirectory(mlir-translate)
 add_subdirectory(mlir-vulkan-runner)
 add_subdirectory(tblgen-lsp-server)
+add_subdirectory(mlir-dialect-bytecode-bootstrap)
 
 # mlir-cpu-runner requires ExecutionEngine.
 if(MLIR_ENABLE_EXECUTION_ENGINE)
diff --git a/mlir/tools/mlir-dialect-bytecode-bootstrap/BuiltinInc.h b/mlir/tools/mlir-dialect-bytecode-bootstrap/BuiltinInc.h
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-dialect-bytecode-bootstrap/BuiltinInc.h
@@ -0,0 +1,20 @@
+static LogicalResult readAPIntWithKnownWidth(DialectBytecodeReader &reader,
+                                             unsigned bitWidth, APInt &val) { // Out-param wrapper over reader.readAPIntWithKnownWidth(bitWidth).
+  FailureOr value = reader.readAPIntWithKnownWidth(bitWidth); // NOTE(review): template arguments look stripped in this copy (presumably FailureOr<APInt>, and dyn_cast<IntegerType> / isa<IndexType> below); confirm.
+  if (failed(value))
+    return failure();
+  val = *value; // `val` is only assigned when the read succeeds.
+  return success();
+}
+
+// Returns the bit width for integer/index types; else emits an error, returns 0.
+static unsigned getIntegerBitWidth(DialectBytecodeReader &reader, Type type) {
+  if (auto intType = dyn_cast(type)) {
+    return intType.getWidth();
+  } else if (type.isa()) {
+    return IndexType::kInternalStorageBitWidth;
+  }
+  reader.emitError()
+      << "expected integer or index type for IntegerAttr, but got: " << type;
+  return 0;
+}
diff --git a/mlir/tools/mlir-dialect-bytecode-bootstrap/BytecodeAttrType.td b/mlir/tools/mlir-dialect-bytecode-bootstrap/BytecodeAttrType.td
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-dialect-bytecode-bootstrap/BytecodeAttrType.td
@@ -0,0 +1,517 @@
+// Bytecode base classes/defs.
+// Helper classes/defs used by mlir-bc-tblgen to build a C++ reader/writer.
+// NOTE(review): several `<...>` template argument lists appear to be missing
+// from this copy of the patch (e.g. `class Bytecode` uses `parse`, `build`,
+// `print`, `get` and `t` without declaring them); confirm against the original.
+
+class Bytecode {
+  // Template for parsing.
+  // {0} == dialect bytecode reader
+  // {1} == result type of parsed instance
+  // {2} == variable being parsed
+  // If parser is not specified, then the parse of members is used.
+  string cParser = parse;
+
+  // Template for building from parsed.
+  // {0} == result type of parsed instance
+  // {1} == args/members comma separated
+  string cBuilder = build;
+
+  // Template for printing.
+  // {0} == dialect bytecode writer
+  // {1} == parent attribute/type name
+  // {2} == getter
+  string cPrinter = print;
+
+  // Template for getter from in memory form.
+  // {0} == attribute/type value being printed
+  // {1} == "get" + UpperCamelFromSnake(member name)
+  // If empty, `{0}.{1}()` is used.
+  string cGetter = get;
+
+  // Type built.
+  // Note: if cType is empty, then name of def is used.
+  string cType = t;
+
+  // Predicate guarding the print method: an Attribute/Type (cType) may have
+  // multiple bytecode encodings, so specify predicates that are mutually
+  // exclusive and together cover every value to avoid order dependence.
+  // If empty then method is unconditional.
+  // {0} == the value being printed, already of type cType.
+  string printerPredicate = "";
+}
+
+class WithParser> :
+  Bytecode;
+class WithBuilder> :
+  Bytecode;
+class WithPrinter> :
+  Bytecode;
+class WithType> :
+  Bytecode;
+class WithGetter> :
+  Bytecode;
+
+class CompositeBytecode : WithType;
+
+class AttributeKind :
+  WithParser <"succeeded({0}.readAttribute({2}))",
+  WithBuilder<"{1}",
+  WithPrinter<"{0}.writeAttribute({2})">>>;
+def Attribute : AttributeKind;
+class TypeKind :
+  WithParser <"succeeded({0}.readType({2}))",
+  WithBuilder<"{1}",
+  WithPrinter<"{0}.writeType({2})">>>;
+def Type : TypeKind;
+def VarInt :
+  WithParser <"succeeded({0}.readVarInt({2}))",
+  WithBuilder<"{1}",
+  WithPrinter<"{0}.writeVarInt({2})",
+  WithType <"uint64_t">>>>;
+
+class KnownWidthAPInt :
+  WithParser <"succeeded(readAPIntWithKnownWidth({0}, " # s # ", {2}))",
+  WithBuilder<"{1}",
+  WithPrinter<"{0}.writeAPIntWithKnownWidth({2})",
+  WithType <"APInt">>>>;
+
+// Helper for a local variable assigned from the expression `d` (not read from
+// the bytecode) and never printed.
+class LocalVar :
+  WithParser <"(({2} = " # d # "), true)",
+  WithBuilder<"{1}",
+  WithPrinter<"",
+  WithType >>>;
+
+// Array instances.
+class Array {
+  Bytecode elemT = t;
+
+  string cBuilder = "{1}";
+}
+
+// Define dialect attribute or type.
+class DialectAttrOrType {
+  int enum = e;
+  // Members whose names start with an underscore are not passed to the
+  // builder (create) function; they are treated as purely local variables.
+  dag members = d;
+  string dialect = di;
+
+  // When needing to specify a custom return type.
+  string cType = "";
+
+  // Any post-processing that needs to be done.
+  // NOTE(review): not currently read by mlir-bc-tblgen.cpp.
+  code postProcess = "";
+}
+
+class DialectAttribute : DialectAttrOrType,
+  AttributeKind {
+  let cParser = "succeeded({0}.readAttribute<{1}>({2}))";
+  let cBuilder = "{0}::get({1})";
+}
+class DialectType : DialectAttrOrType,
+  TypeKind {
+  // NOTE(review): this reads an *attribute* for a type kind; presumably it
+  // should be readType<{1}> (matching TypeKind) - confirm.
+  let cParser = "succeeded({0}.readAttribute<{1}>({2}))";
+  let cBuilder = "{0}::get({1})";
+}
+
+def Context :
+  WithParser <"({2}=context)",
+  WithBuilder<"{1}",
+  WithPrinter<"",
+  WithType <"MLIRContext*">>>>;
+
+def attr;
+def type;
+
+// ----------------
+// Builtin dialect.
+// ================
+
+class BuiltinDialectAttribute : DialectAttribute<"Builtin", e, d>;
+class BuiltinDialectType : DialectType<"Builtin", e, d>;
+
+def LocationAttr : AttributeKind;
+
+def Location : CompositeBytecode {
+  dag members = (attr
+    WithGetter<"(LocationAttr){0}", WithType<"LocationAttr", LocationAttr>>:$value
+  );
+  let cBuilder = "Location({1})";
+}
+
+def String :
+  WithParser <"succeeded({0}.readString({2}))",
+  WithBuilder<"{1}",
+  WithPrinter<"{0}.writeOwnedString({2})",
+  WithType <"StringRef">>>>;
+
+// Encodings that are only described in comments (no `def`) get no generated
+// reader/writer.
+// enum AttributeCode {
+//   ///   ArrayAttr {
+//   ///     elements: Attribute[]
+//   ///   }
+//   ///
+//   kArrayAttr = 0,
+//
+
+let cType = "StringAttr" in {
+//   ///   StringAttr {
+//   ///     value: string
+//   ///   }
+//   kStringAttr = 2,
+def StringAttr : BuiltinDialectAttribute {
+  let printerPredicate = "{0}.getType().isa()";
+  let cBuilder = "StringAttr::get(context, {1})";
+}
+
+//   ///   StringAttrWithType {
+//   ///     value: string,
+//   ///     type: Type
+//   ///   }
+//   ///   A variant of StringAttr with a type.
+//   kStringAttrWithType = 3,
+def StringAttrWithType : BuiltinDialectAttribute { let printerPredicate = "!{0}.getType().isa()"; }
+}
+
+//   ///   DictionaryAttr {
+//   ///     attrs: (StringAttr, Attribute)[]
+//   ///   }
+//   kDictionaryAttr = 1,
+def NamedAttribute : CompositeBytecode {
+  dag members = (attr
+    StringAttr:$name,
+    Attribute:$value
+  );
+  let cBuilder = "NamedAttribute({1})";
+}
+def DictionaryAttr : BuiltinDialectAttribute:$value
+)>;
+
+//   ///   FlatSymbolRefAttr {
+//   ///     rootReference: StringAttr
+//   ///   }
+//   ///   A variant of SymbolRefAttr with no leaf references.
+//   kFlatSymbolRefAttr = 4,
+def FlatSymbolRefAttr: BuiltinDialectAttribute;
+
+//   ///   SymbolRefAttr {
+//   ///     rootReference: StringAttr,
+//   ///     leafReferences: FlatSymbolRefAttr[]
+//   ///   }
+//   kSymbolRefAttr = 5,
+def SymbolRefAttr: BuiltinDialectAttribute:$nestedReferences
+)>;
+
+//   ///   TypeAttr {
+//   ///     value: Type
+//   ///   }
+//   kTypeAttr = 6,
+def TypeAttr: BuiltinDialectAttribute;
+
+//   ///   UnitAttr {
+//   ///   }
+//   kUnitAttr = 7,
+def UnitAttr: BuiltinDialectAttribute;
+
+//   ///   IntegerAttr {
+//   ///     type: Type
+//   ///     value: APInt,
+//   ///   }
+//   kIntegerAttr = 8,
+def IntegerAttr: BuiltinDialectAttribute:$_width,
+  KnownWidthAPInt<"_width">:$value
+)>;
+
+//
+//   ///   FloatAttr {
+//   ///     type: FloatType
+//   ///     value: APFloat
+//   ///   }
+//   kFloatAttr = 9,
+//
+//   ///   CallSiteLoc {
+//   ///    callee: LocationAttr,
+//   ///    caller: LocationAttr
+//   ///   }
+//   kCallSiteLoc = 10,
+def CallSiteLoc : BuiltinDialectAttribute;
+
+//   ///   FileLineColLoc {
+//   ///     filename: StringAttr,
+//   ///     line: varint,
+//   ///     column: varint
+//   ///   }
+//   kFileLineColLoc = 11,
+def FileLineColLoc : BuiltinDialectAttribute;
+
+let cType = "FusedLoc" in {
+//   ///   FusedLoc {
+//   ///     locations: Location[]
+//   ///   }
+//   kFusedLoc = 12,
+def FusedLoc : BuiltinDialectAttribute:$locations
+)> {
+  let printerPredicate = "!{0}.getMetadata()";
+  let cBuilder = "cast(FusedLoc::get(context, {1}))";
+}
+
+//   ///   FusedLocWithMetadata {
+//   ///     locations: LocationAttr[],
+//   ///     metadata: Attribute
+//   ///   }
+//   ///   A variant of FusedLoc with metadata.
+//   kFusedLocWithMetadata = 13,
+def FusedLocWithMetadata : BuiltinDialectAttribute:$locations,
+  Attribute:$metadata
+)> {
+  let printerPredicate = "{0}.getMetadata()";
+  let cBuilder = "cast(FusedLoc::get(context, {1}))";
+}
+}
+
+//   ///   NameLoc {
+//   ///     name: StringAttr,
+//   ///     childLoc: LocationAttr
+//   ///   }
+//   kNameLoc = 14,
+def NameLoc : BuiltinDialectAttribute;
+
+//   ///   UnknownLoc {
+//   ///   }
+//   kUnknownLoc = 15,
+def UnknownLoc : BuiltinDialectAttribute;
+
+//   ///   DenseResourceElementsAttr {
+//   ///     type: Type,
+//   ///     handle: ResourceHandle
+//   ///   }
+//   kDenseResourceElementsAttr = 16,
+//
+//   ///   DenseArrayAttr {
+//   ///     type: RankedTensorType,
+//   ///     data: blob
+//   ///   }
+//   kDenseArrayAttr = 17,
+//
+//   ///   DenseIntOrFPElementsAttr {
+//   ///     type: ShapedType,
+//   ///     data: blob
+//   ///   }
+//   kDenseIntOrFPElementsAttr = 18,
+//
+//   ///   DenseStringElementsAttr {
+//   ///     type: ShapedType,
+//   ///     isSplat: varint,
+//   ///     data: string[]
+//   ///   }
+//   kDenseStringElementsAttr = 19,
+//
+//   ///   SparseElementsAttr {
+//   ///     type: ShapedType,
+//   ///     indices: DenseIntElementsAttr,
+//   ///     values: DenseElementsAttr
+//   ///   }
+//   kSparseElementsAttr = 20,
+
+// Types
+// -----
+
+// enum TypeCode {
+//   ///   IntegerType {
+//   ///     widthAndSignedness: varint // (width << 2) | (signedness)
+//   ///   }
+//   ///
+//   kIntegerType = 0,
+def IntegerType : BuiltinDialectType>>>:$_widthAndSignedness,
+  // Split the parsed varint into the width/signedness passed to the builder.
+  LocalVar<"uint64_t", "_widthAndSignedness >> 2">:$width,
+  LocalVar<"IntegerType::SignednessSemantics",
+    "static_cast(_widthAndSignedness & 0x3)">:$signedness
+)>;
+
+//
+//   ///   IndexType {
+//   ///   }
+//   ///
+//   kIndexType = 1,
+def IndexType : BuiltinDialectType;
+
+//   ///   FunctionType {
+//   ///     inputs: Type[],
+//   ///     results: Type[]
+//   ///   }
+//   ///
+//   kFunctionType = 2,
+def FunctionType : BuiltinDialectType:$inputs,
+  Array:$results
+)>;
+
+//   ///   BFloat16Type {
+//   ///   }
+//   ///
+//   kBFloat16Type = 3,
+def BFloat16Type : BuiltinDialectType;
+
+//   ///   Float16Type {
+//   ///   }
+//   ///
+//   kFloat16Type = 4,
+def Float16Type : BuiltinDialectType;
+
+//   ///   Float32Type {
+//   ///   }
+//   ///
+//   kFloat32Type = 5,
+def Float32Type : BuiltinDialectType;
+
+//   ///   Float64Type {
+//   ///   }
+//   ///
+//   kFloat64Type = 6,
+def Float64Type : BuiltinDialectType;
+
+//   ///   Float80Type {
+//   ///   }
+//   ///
+//   kFloat80Type = 7,
+def Float80Type : BuiltinDialectType;
+
+//   ///   Float128Type {
+//   ///   }
+//   ///
+//   kFloat128Type = 8,
+def Float128Type : BuiltinDialectType;
+
+//   ///   ComplexType {
+//   ///     elementType: Type
+//   ///   }
+//   ///
+//   kComplexType = 9,
+def ComplexType : BuiltinDialectType;
+
+//   ///   MemRefType {
+//   ///     shape: svarint[],
+//   ///     elementType: Type,
+//   ///     layout: Attribute
+//   ///   }
+//   ///
+//   kMemRefType = 10,
+//
+//   ///   MemRefTypeWithMemSpace {
+//   ///     memorySpace: Attribute,
+//   ///     shape: svarint[],
+//   ///     elementType: Type,
+//   ///     layout: Attribute
+//   ///   }
+//   ///   Variant of MemRefType with non-default memory space.
+//   kMemRefTypeWithMemSpace = 11,
+//
+//   ///   NoneType {
+//   ///   }
+//   ///
+//   kNoneType = 12,
+// NOTE(review): the template arguments of the defs in this section (enum value
+// and member dag) appear to be missing in this copy; mlir-bc-tblgen reads
+// `enum` via getValueAsInt("enum"), so they must be present.
+def NoneType : BuiltinDialectType;
+
+// Encodings below that are only described in comments (RankedTensorType,
+// UnrankedMemRefType, VectorType and their variants) have no `def`, so no
+// reader/writer is generated for them.
+//   ///   RankedTensorType {
+//   ///     shape: svarint[],
+//   ///     elementType: Type,
+//   ///   }
+//   ///
+//   kRankedTensorType = 13,
+//
+//   ///   RankedTensorTypeWithEncoding {
+//   ///     encoding: Attribute,
+//   ///     shape: svarint[],
+//   ///     elementType: Type
+//   ///   }
+//   ///   Variant of RankedTensorType with an encoding.
+//   kRankedTensorTypeWithEncoding = 14,
+//
+//   ///   TupleType {
+//   ///     elementTypes: Type[]
+//   ///   }
+//   kTupleType = 15,
+// The member is named `types`; unless it supplies a custom getter, the
+// generated writer reads it via `getTypes()` (derived from the member name).
+def TupleType : BuiltinDialectType:$types
+)>;
+
+//   ///   UnrankedMemRefType {
+//   ///     shape: svarint[]
+//   ///   }
+//   ///
+//   kUnrankedMemRefType = 16,
+//
+//   ///   UnrankedMemRefTypeWithMemSpace {
+//   ///     memorySpace: Attribute,
+//   ///     shape: svarint[]
+//   ///   }
+//   ///   Variant of UnrankedMemRefType with non-default memory space.
+//   kUnrankedMemRefTypeWithMemSpace = 17,
+//
+//   ///   UnrankedTensorType {
+//   ///     elementType: Type
+//   ///   }
+//   ///
+//   kUnrankedTensorType = 18,
+def UnrankedTensorType : BuiltinDialectType;
+
+//   ///   VectorType {
+//   ///     shape: svarint[],
+//   ///     elementType: Type
+//   ///   }
+//   ///
+//   kVectorType = 19,
+//
+//   ///   VectorTypeWithScalableDims {
+//   ///     numScalableDims: varint,
+//   ///     shape: svarint[],
+//   ///     elementType: Type
+//   ///   }
+//   ///   Variant of VectorType with scalable dimensions.
+//   kVectorTypeWithScalableDims = 20,
diff --git a/mlir/tools/mlir-dialect-bytecode-bootstrap/CMakeLists.txt b/mlir/tools/mlir-dialect-bytecode-bootstrap/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-dialect-bytecode-bootstrap/CMakeLists.txt
@@ -0,0 +1,16 @@
+set(LLVM_OPTIONAL_SOURCES
+  BuiltinInc.cpp
+)
+
+add_executable(mlir-bc-tblgen
+  mlir-bc-tblgen.cpp
+  )
+set_target_properties(mlir-bc-tblgen PROPERTIES FOLDER "Tablegenning")
+target_link_libraries(mlir-bc-tblgen
+  PRIVATE
+  LLVMDemangle
+  LLVMSupport
+  LLVMTableGen
+  MLIRSupportIndentedOstream
+  )
+llvm_update_compile_flags(mlir-bc-tblgen)
diff --git a/mlir/tools/mlir-dialect-bytecode-bootstrap/README.md b/mlir/tools/mlir-dialect-bytecode-bootstrap/README.md
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-dialect-bytecode-bootstrap/README.md
@@ -0,0 +1,15 @@
+# Dialect bytecode parser bootstrap
+
+Simple tool to help bootstrap the dialect bytecode parsing definitions.
+It is not meant as a full "spec"; rather, it avoids writing boilerplate.
+
+It is meant to be retargetable to read/write other forms, so most
+specialization should happen on the TableGen side. This is not there yet:
+there is currently hardcoded behavior on the C++ side that will be removed,
+but it was possible to switch it between different formats in roughly a day.
+
+This is a separate binary rather than being linked into mlir-tblgen because
+it is a simple helper tool to be copied/modified/retargeted, not a core feature.
+
+This tool is not meant to be run as part of an automated build script.
+
diff --git a/mlir/tools/mlir-dialect-bytecode-bootstrap/mlir-bc-tblgen.cpp b/mlir/tools/mlir-dialect-bytecode-bootstrap/mlir-bc-tblgen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-dialect-bytecode-bootstrap/mlir-bc-tblgen.cpp
@@ -0,0 +1,517 @@
+//===- mlir-bc-tblgen.cpp - TableGen helper for MLIR bytecode -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Bootstrapping helper that emits C++ reader/writer functions for a dialect's
+// attributes and types from bytecode description records (see
+// BytecodeAttrType.td).
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/IndentedOstream.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Main.h"
+#include "llvm/TableGen/Record.h"
+
+#include <filesystem>
+#include <map>
+#include <string>
+#include <vector>
+
+using namespace llvm;
+
+static cl::opt<std::string>
+    selectedDialect("dialect", cl::desc("The dialect to gen for"));
+
+/// Name of this executable, used to prefix diagnostics; set in main().
+static const char *progName = "mlir-bc-tblgen";
+
+/// Prints `msg`, prefixed with the program name, to stderr. Always returns
+/// true so TableGen main functions can `return reportError(...)` on failure.
+static bool reportError(const Twine &msg) {
+  errs() << progName << ": " << msg << "\n";
+  return true;
+}
+
+namespace {
+/// Accumulates the generated C++ reader/writer code for a single dialect.
+/// Output is buffered in a top section (the optional `<Dialect>Inc.h` helper
+/// file) and a bottom section (generated functions); `fin` writes both.
+class Generator {
+public:
+  explicit Generator(StringRef dialectName) : dialectName(dialectName) {}
+
+  /// Copies `<mainFileRoot>/<Dialect>Inc.h`, if present, into the top section.
+  /// Aborts with a fatal error if the file exists but cannot be read.
+  void init(const std::filesystem::path &mainFileRoot);
+
+  /// Writes the top section followed by the bottom section to `os`.
+  void fin(raw_ostream &os);
+
+  /// Emits a `read<DefName>` function parsing the attribute/type record.
+  void emitParse(StringRef kind, Record &attr);
+
+  /// Emits a `write` function covering every record whose C++ type is `type`.
+  void emitPrint(StringRef kind, StringRef type, ArrayRef<Record *> vec);
+
+  /// Emits the `read<Kind>` function dispatching on the encoded enum value.
+  void emitParseDispatch(StringRef kind, ArrayRef<Record *> vec);
+
+  /// Emits the `write<Kind>` function dispatching on the C++ type.
+  void emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec);
+
+private:
+  /// Emits a braced body that parses `args` (named `argNames`) and returns the
+  /// value built from the `builder` template, or `failure` if parsing fails.
+  void emitParseHelper(StringRef kind, StringRef returnType, StringRef builder,
+                       ArrayRef<Init *> args, ArrayRef<std::string> argNames,
+                       StringRef failure, mlir::raw_indented_ostream &ios);
+
+  /// Emits the writer calls for member `name` of the value named `parent`.
+  void emitPrintHelper(Record *memberRec, StringRef kind, StringRef parent,
+                       StringRef name, mlir::raw_indented_ostream &ios);
+
+  StringRef dialectName;
+  std::string topStr, bottomStr;
+};
+
+/// Attribute and type records belonging to one dialect.
+struct AttrOrType {
+  std::vector<Record *> attr, type;
+};
+} // namespace
+
+void Generator::init(const std::filesystem::path &mainFileRoot) {
+  raw_string_ostream os(topStr);
+
+  // Inject the optional dialect helper header (e.g. BuiltinInc.h) so helpers
+  // referenced from the parser templates are available to generated code.
+  std::filesystem::path incHeaderFileName =
+      mainFileRoot / formatv("{0}Inc.h", dialectName).str();
+  if (!std::filesystem::exists(incHeaderFileName))
+    return;
+
+  auto incFileOr =
+      MemoryBuffer::getFile(incHeaderFileName.string(), /*IsText=*/true,
+                            /*RequiresNullTerminator=*/false);
+  if (!incFileOr)
+    PrintFatalError("FAILURE: " + incFileOr.getError().message());
+  os << (*incFileOr)->getBuffer() << "\n";
+}
+
+void Generator::fin(raw_ostream &os) {
+  os << topStr << "\n" << StringRef(bottomStr).rtrim() << "\n";
+}
+
+/// Returns `str` with its first character upper-cased.
+static std::string capitalize(StringRef str) {
+  if (str.empty())
+    return "";
+  return (Twine(toUpper(str.front())) + str.drop_front()).str();
+}
+
+/// Returns the C++ type for a member/def record: its `cType` if set, else the
+/// def name; `Array` records map to `SmallVector<element C++ type>`.
+static std::string getCType(Record *def) {
+  std::string format = "{0}";
+  if (def->isSubClassOf("Array")) {
+    def = def->getValueAsDef("elemT");
+    format = "SmallVector<{0}>";
+  }
+
+  StringRef cType = def->getValueAsString("cType");
+  if (cType.empty()) {
+    if (def->isAnonymous())
+      PrintFatalError(def->getLoc(), "Unable to determine cType");
+    return formatv(format.c_str(), def->getName()).str();
+  }
+  return formatv(format.c_str(), cType).str();
+}
+
+void Generator::emitParseDispatch(StringRef kind, ArrayRef<Record *> vec) {
+  raw_string_ostream sos(bottomStr);
+  mlir::raw_indented_ostream os(sos);
+  std::string kindType = capitalize(kind);
+  const char *head =
+      R"(static {0} read{0}(MLIRContext* context, DialectBytecodeReader &reader))";
+  os << formatv(head, kindType);
+  auto funScope = os.scope(" {\n", "}\n\n");
+
+  os << "uint64_t kind;\n";
+  os << "if (failed(reader.readVarInt(kind)))\n"
+     << "  return " << kindType << "();\n";
+  os << "switch (kind) ";
+  {
+    auto switchScope = os.scope("{\n", "}\n");
+    for (Record *rec : vec) {
+      os << formatv("case /* {0} */ {1}:\n  return read{0}(context, reader);\n",
+                    rec->getName(), rec->getValueAsInt("enum"));
+    }
+    // Name the actual dialect and kind rather than always reporting a
+    // "builtin attribute" code.
+    os << "default:\n"
+       << formatv("  reader.emitError() << \"unknown {0} {1} code: \" "
+                  "<< kind;\n",
+                  dialectName.lower(), kind)
+       << "  return " << kindType << "();\n";
+  }
+  os << "return " << kindType << "();\n";
+}
+
+void Generator::emitParse(StringRef kind, Record &attr) {
+  const char *head =
+      R"(static {0} read{1}(MLIRContext* context, DialectBytecodeReader &reader) )";
+  raw_string_ostream os(bottomStr);
+  mlir::raw_indented_ostream ios(os);
+  std::string returnType = getCType(&attr);
+  ios << formatv(head, returnType, attr.getName());
+  DagInit *members = attr.getValueAsDag("members");
+  SmallVector<std::string> argNames = llvm::to_vector(
+      map_range(members->getArgNames(), [](StringInit *init) {
+        return init->getAsUnquotedString();
+      }));
+  StringRef builder = attr.getValueAsString("cBuilder");
+  emitParseHelper(kind, returnType, builder, members->getArgs(), argNames,
+                  returnType + "()", ios);
+}
+
+void Generator::emitParseHelper(StringRef kind, StringRef returnType,
+                                StringRef builder, ArrayRef<Init *> args,
+                                ArrayRef<std::string> argNames,
+                                StringRef failure,
+                                mlir::raw_indented_ostream &ios) {
+  auto funScope = ios.scope("{\n", "}\n\n");
+
+  if (args.empty()) {
+    ios << formatv("return {0}::get(context);\n", returnType);
+    return;
+  }
+
+  // Declare a local per member; consecutive members sharing a C++ type are
+  // declared together.
+  std::string lastCType = "";
+  for (auto [arg, name] : zip(args, argNames)) {
+    DefInit *first = dyn_cast<DefInit>(arg);
+    if (!first)
+      PrintFatalError("Unexpected type for " + name);
+    Record *def = first->getDef();
+
+    std::string cType = getCType(def);
+    if (lastCType == cType) {
+      ios << ", ";
+    } else {
+      if (!lastCType.empty())
+        ios << ";\n";
+      ios << cType << " ";
+    }
+    ios << name;
+    lastCType = cType;
+  }
+  ios << ";\n";
+
+  auto listHelperName = [](StringRef name) {
+    return formatv("read{0}", capitalize(name)).str();
+  };
+
+  // Emit element-reader lambdas for arrays whose elements are not plain
+  // attributes/types (those are read with readAttributes/readTypes).
+  for (auto [arg, name] : zip(args, argNames)) {
+    Record *attr = cast<DefInit>(arg)->getDef();
+    if (!attr->isSubClassOf("Array"))
+      continue;
+
+    // TODO: Dedupe readers.
+    Record *def = attr->getValueAsDef("elemT");
+    if (!def->isSubClassOf("CompositeBytecode") &&
+        (def->isSubClassOf("AttributeKind") || def->isSubClassOf("TypeKind")))
+      continue;
+
+    std::string elemType = getCType(def);
+    ios << "auto " << listHelperName(name) << " = [&]() -> FailureOr<"
+        << elemType << "> ";
+    SmallVector<Init *> elemArgs;
+    SmallVector<std::string> elemArgNames;
+    if (def->isSubClassOf("CompositeBytecode")) {
+      DagInit *members = def->getValueAsDag("members");
+      elemArgs = llvm::to_vector(members->getArgs());
+      elemArgNames = llvm::to_vector(
+          map_range(members->getArgNames(), [](StringInit *init) {
+            return init->getAsUnquotedString();
+          }));
+    } else {
+      elemArgs = {def->getDefInit()};
+      elemArgNames = {"temp"};
+    }
+    StringRef elemBuilder = def->getValueAsString("cBuilder");
+    emitParseHelper(kind, elemType, elemBuilder, elemArgs, elemArgNames,
+                    "failure()", ios);
+    ios << ";\n";
+  }
+
+  // Emit the parse conditional: the conjunction of one parse call per member
+  // that has a parser (arrays always do).
+  {
+    ios << "if ";
+    auto parenScope = ios.scope("(", ") {");
+    ios.indent();
+
+    // Filter members together with their names so each parse call is emitted
+    // for the member's own variable.
+    SmallVector<std::pair<Record *, StringRef>> parsedArgs;
+    for (auto [arg, name] : zip(args, argNames)) {
+      Record *def = cast<DefInit>(arg)->getDef();
+      if (def->isSubClassOf("Array") ||
+          !def->getValueAsString("cParser").empty())
+        parsedArgs.emplace_back(def, name);
+    }
+
+    interleave(
+        parsedArgs,
+        [&](const std::pair<Record *, StringRef> &parsedArg) {
+          Record *attr = parsedArg.first;
+          StringRef name = parsedArg.second;
+          std::string parser;
+          if (auto optParser = attr->getValueAsOptionalString("cParser")) {
+            parser = optParser->str();
+          } else if (attr->isSubClassOf("Array")) {
+            Record *def = attr->getValueAsDef("elemT");
+            bool composite = def->isSubClassOf("CompositeBytecode");
+            if (!composite && def->isSubClassOf("AttributeKind"))
+              parser = "succeeded({0}.readAttributes({2}))";
+            else if (!composite && def->isSubClassOf("TypeKind"))
+              parser = "succeeded({0}.readTypes({2}))";
+            else
+              parser = "succeeded({0}.readList({2}, " + listHelperName(name) +
+                       "))";
+          } else {
+            PrintFatalError(attr->getLoc(), "No parser specified");
+          }
+          std::string type = getCType(attr);
+          ios << formatv(parser.c_str(), "reader", type, name);
+        },
+        [&]() { ios << " &&\n"; });
+  }
+
+  // Pass every member except '_'-prefixed (local-only) ones to the builder.
+  auto passedArgs = llvm::to_vector(make_filter_range(
+      argNames, [](StringRef str) { return !str.starts_with("_"); }));
+  std::string argStr;
+  raw_string_ostream argStream(argStr);
+  interleaveComma(passedArgs, argStream,
+                  [&](const std::string &str) { argStream << str; });
+  // Return the constructed value.
+  ios << "\nreturn "
+      << formatv(builder.str().c_str(), returnType, argStream.str()) << ";\n";
+  ios.unindent();
+
+  // TODO: Emit an error in debug builds. This assumes the result type can
+  // always be default constructed in the error case.
+  ios << "}\nreturn " << failure << ";\n";
+}
+
+void Generator::emitPrint(StringRef kind, StringRef type,
+                          ArrayRef<Record *> vec) {
+  const char *head =
+      R"(static void write({0} {1}, DialectBytecodeWriter &writer) )";
+  raw_string_ostream os(bottomStr);
+  mlir::raw_indented_ostream ios(os);
+  ios << formatv(head, type, kind);
+  auto funScope = ios.scope("{\n", "}\n\n");
+
+  // Encodings sharing a C++ type are selected by their printer predicates, so
+  // each of them must specify one.
+  if (vec.size() > 1) {
+    bool missingPredicate = false;
+    for (Record *rec : vec) {
+      if (!rec->getValueAsString("printerPredicate").empty())
+        continue;
+      PrintError(rec->getLoc(),
+                 "Requires printer predicate given common cType");
+      missingPredicate = true;
+    }
+    if (missingPredicate)
+      PrintFatalError("Unspecified printer predicate for shared cType " +
+                      type);
+  }
+
+  for (Record *rec : vec) {
+    StringRef pred = rec->getValueAsString("printerPredicate");
+    if (!pred.empty()) {
+      ios << "if (" << formatv(pred.str().c_str(), kind) << ") {\n";
+      ios.indent();
+    }
+
+    ios << "writer.writeVarInt(/* " << rec->getName() << " */ "
+        << rec->getValueAsInt("enum") << ");\n";
+
+    DagInit *members = rec->getValueAsDag("members");
+    for (auto [arg, name] :
+         llvm::zip(members->getArgs(), members->getArgNames())) {
+      DefInit *def = dyn_cast<DefInit>(arg);
+      if (!def)
+        PrintFatalError(rec->getLoc(),
+                        "Unexpected type for " + name->getAsUnquotedString());
+      emitPrintHelper(def->getDef(), kind, kind, name->getAsUnquotedString(),
+                      ios);
+    }
+
+    if (!pred.empty()) {
+      ios.unindent();
+      ios << "}\n";
+    }
+  }
+}
+
+void Generator::emitPrintHelper(Record *memberRec, StringRef kind,
+                                StringRef parent, StringRef name,
+                                mlir::raw_indented_ostream &ios) {
+  // Expression yielding the member value: `parent.get<Name>()` unless the
+  // record provides a custom getter template.
+  std::string getterName = "get" + convertToCamelFromSnakeCase(name, true);
+  std::string getter;
+  if (auto cGetter = memberRec->getValueAsOptionalString("cGetter");
+      cGetter && !cGetter->empty())
+    getter = formatv(cGetter->str().c_str(), parent, getterName).str();
+  else
+    getter = formatv("{0}.{1}()", parent, getterName).str();
+
+  if (memberRec->isSubClassOf("Array")) {
+    Record *def = memberRec->getValueAsDef("elemT");
+    if (!def->isSubClassOf("CompositeBytecode")) {
+      if (def->isSubClassOf("AttributeKind")) {
+        ios << "writer.writeAttributes(" << getter << ");\n";
+        return;
+      }
+      if (def->isSubClassOf("TypeKind")) {
+        ios << "writer.writeTypes(" << getter << ");\n";
+        return;
+      }
+    }
+    // Other arrays are written element-wise; the lambda parameter is named
+    // `kind` and acts as the parent of the element's members.
+    std::string elemType = getCType(def);
+    ios << "writer.writeList(" << getter << ", [&](" << elemType << " "
+        << kind << ") ";
+    auto lambdaScope = ios.scope("{\n", "});\n");
+    emitPrintHelper(def, kind, kind, kind, ios);
+    return;
+  }
+
+  if (memberRec->isSubClassOf("CompositeBytecode")) {
+    DagInit *members = memberRec->getValueAsDag("members");
+    for (auto [arg, argName] :
+         zip(members->getArgs(), members->getArgNames())) {
+      std::string argStr = argName->getAsUnquotedString();
+      DefInit *def = dyn_cast<DefInit>(arg);
+      if (!def)
+        PrintFatalError(memberRec->getLoc(), "Unexpected type for " + argStr);
+      emitPrintHelper(def->getDef(), kind, parent, argStr, ios);
+    }
+  }
+
+  StringRef printer = memberRec->getValueAsString("cPrinter");
+  if (!printer.empty())
+    ios << formatv(printer.str().c_str(), "writer", kind, getter) << ";\n";
+}
+
+void Generator::emitPrintDispatch(StringRef kind, ArrayRef<std::string> vec) {
+  raw_string_ostream sos(bottomStr);
+  mlir::raw_indented_ostream os(sos);
+  std::string kindType = capitalize(kind);
+  const char *head =
+      R"(static LogicalResult write{0}({0} {1}, DialectBytecodeWriter &writer))";
+  os << formatv(head, kindType, kind);
+  auto funScope = os.scope(" {\n", "}\n\n");
+
+  os << "return TypeSwitch<" << kindType << ", LogicalResult>(" << kind
+     << ")";
+  auto switchScope = os.scope("", "");
+  for (StringRef type : vec) {
+    os << "\n.Case([&](" << type << " t)";
+    auto caseScope = os.scope(" {\n", "})");
+    os << "return write(t, writer), success();\n";
+  }
+  os << "\n.Default([&](" << kindType << ") { return failure(); });\n";
+}
+
+/// TableGen entry point: emits the reader/writer helpers for the single
+/// dialect present (or selected via -dialect) to `os`. Returns true on error.
+static bool tableGenMain(raw_ostream &os, RecordKeeper &records) {
+  MapVector<StringRef, AttrOrType> dialectAttrOrType;
+  Record *attr = records.getClass("DialectAttribute");
+  Record *type = records.getClass("DialectType");
+  for (Record *def : records.getAllDerivedDefinitions("DialectAttrOrType")) {
+    StringRef dialect = def->getValueAsString("dialect");
+    if (!selectedDialect.empty() && dialect != selectedDialect)
+      continue;
+
+    if (def->isSubClassOf(attr))
+      dialectAttrOrType[dialect].attr.push_back(def);
+    else if (def->isSubClassOf(type))
+      dialectAttrOrType[dialect].type.push_back(def);
+  }
+
+  if (dialectAttrOrType.size() != 1)
+    return reportError("Single dialect per invocation required (either only "
+                       "one in input file or specified via dialect option)");
+
+  // Orders records by their encoding enum value.
+  auto compEnum = [](Record *lhs, Record *rhs) {
+    return lhs->getValueAsInt("enum") < rhs->getValueAsInt("enum");
+  };
+
+  auto mainFile =
+      std::filesystem::path(records.getInputFilename()).remove_filename();
+
+  auto &[dialect, attrsAndTypes] = dialectAttrOrType.front();
+  Generator gen(dialect);
+  gen.init(mainFile);
+
+  std::pair<StringRef, std::vector<Record *> *> kinds[] = {
+      {"attribute", &attrsAndTypes.attr}, {"type", &attrsAndTypes.type}};
+  for (auto &[kind, vec] : kinds) {
+    llvm::sort(*vec, compEnum);
+    // Group records by C++ type; all encodings of one type share a writer.
+    std::map<std::string, std::vector<Record *>> perType;
+    for (Record *rec : *vec)
+      perType[getCType(rec)].push_back(rec);
+    for (const auto &[cType, recs] : perType) {
+      for (Record *rec : recs)
+        gen.emitParse(kind, *rec);
+      gen.emitPrint(kind, cType, recs);
+    }
+    gen.emitParseDispatch(kind, *vec);
+
+    SmallVector<std::string> types;
+    for (const auto &entry : perType)
+      types.push_back(entry.first);
+    gen.emitPrintDispatch(kind, types);
+  }
+  gen.fin(os);
+
+  return false;
+}
+
+int main(int argc, char **argv) {
+  InitLLVM y(argc, argv);
+
+  cl::ParseCommandLineOptions(argc, argv);
+  progName = argv[0];
+
+  return TableGenMain(argv[0], &tableGenMain);
+}