diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -2246,6 +2246,218 @@
   return false;
 }
 
+struct ArgumentSignature {
+  StringRef type;
+  StringRef name;
+  StringRef initializer;
+};
+
+/// Returns true if `word` is a C++ keyword that can only belong to the type
+/// of a parameter declaration and can never be its name, e.g. the trailing
+/// `const` in `Type *const` or the `int` in `unsigned int`.
+static bool isTypeKeyword(StringRef word) {
+  static const StringRef keywords[] = {
+      "const", "volatile", "bool",   "char",   "short",   "int",
+      "long",  "float",    "double", "signed", "unsigned"};
+  return llvm::is_contained(keywords, word);
+}
+
+/// Splits a single C++ parameter declaration, e.g. `ArrayRef<Value> v = {}`,
+/// into its type (`ArrayRef<Value>`), name (`v`) and default initializer
+/// (`{}`). The name is empty for unnamed parameters and the initializer is
+/// empty when there is no default value. This is a heuristic rather than a
+/// C++ parser: an unnamed parameter whose type ends in a plain identifier
+/// (e.g. `const Value`) is still mistaken for a named one.
+static ArgumentSignature splitSignature(StringRef signature) {
+  signature = signature.trim();
+
+  // First, check if there is an initializer. We find it by looking for the
+  // first '=' because these are unlikely (although not impossible) to find in
+  // the type signature.
+  StringRef initializer;
+  size_t assignmentPos = signature.find('=');
+  if (assignmentPos != StringRef::npos) {
+    initializer = signature.drop_front(assignmentPos + 1).ltrim();
+    signature = signature.take_front(assignmentPos).rtrim();
+  }
+
+  // Walk back over the trailing run of identifier characters: it is the
+  // candidate parameter name.
+  auto isIdentifierChar = [](char c) { return isAlnum(c) || c == '_'; };
+  size_t nameStart = signature.size();
+  while (nameStart > 0 && isIdentifierChar(signature[nameStart - 1]))
+    --nameStart;
+
+  // The candidate is not a name if it makes up the whole signature (a lone
+  // type such as `Value`), if it ends a qualified type name (`mlir::Value`),
+  // or if it is a keyword that belongs to the type.
+  StringRef maybeName = signature.drop_front(nameStart);
+  if (nameStart == 0 || signature[nameStart - 1] == ':' ||
+      isTypeKeyword(maybeName))
+    return {signature, StringRef(), initializer};
+  return {signature.take_front(nameStart).rtrim(), maybeName, initializer};
+}
+
+/// Splits a builder signature, i.e. a comma-separated C++ parameter list, into
+/// its parameter declarations, which are appended untrimmed to `arguments`.
+/// Commas nested in template argument lists, parentheses, brackets or braces
+/// (e.g. `std::pair<int, int> p` or `ArrayRef<int64_t> s = {1, 2}`) do not
+/// separate parameters. An empty signature yields a single empty piece.
+static void splitBuilderSignature(StringRef signature,
+                                  SmallVectorImpl<StringRef> &arguments) {
+  size_t nesting = 0;
+  size_t start = 0;
+  for (size_t pos = 0, e = signature.size(); pos < e; ++pos) {
+    switch (signature[pos]) {
+    case '<':
+    case '(':
+    case '[':
+    case '{':
+      ++nesting;
+      break;
+    case '>':
+    case ')':
+    case ']':
+    case '}':
+      // Ignore unmatched closers instead of letting the unsigned counter wrap
+      // around and swallow every later comma.
+      if (nesting > 0)
+        --nesting;
+      break;
+    case ',':
+      if (nesting == 0) {
+        arguments.push_back(signature.slice(start, pos));
+        start = pos + 1;
+      }
+      break;
+    default:
+      break;
+    }
+  }
+  arguments.push_back(signature.drop_front(start));
+}
+
+/// Formats one parameter declaration as `<type> <name>` or
+/// `<type> <name> = <init>` when a default initializer is given.
+static std::string formatOneArg(StringRef type, Twine name, StringRef init) {
+  if (init.empty())
+    return formatv("{0} {1}", type, name).str();
+  return formatv("{0} {1} = {2}", type, name, init).str();
+}
+
+/// Emits, for each custom ODS builder of the op `def`, one method of the
+/// dialect builder class. The method takes the builder's parameters minus the
+/// leading OpBuilder and OperationState ones and forwards them to
+/// `b.create<Op>(loc, ...)`; unnamed parameters get synthesized names
+/// `odsArgument<N>`.
+static void emitOneBuilderDef(const Record &def, raw_ostream &os) {
+  auto *listInit = dyn_cast_or_null<ListInit>(def.getValueInit("builders"));
+  if (!listInit)
+    return;
+
+  Operator op(def);
+  // The method is named after the op class, so inside the builder class that
+  // name refers to the method and hides the op type, breaking both the return
+  // type and `b.create<...>`. Spell the op type as a qualified elaborated type
+  // specifier, whose lookup ignores non-type names.
+  std::string opType = "class " + op.getQualCppClassName();
+  StringRef methodName = op.getCppClassName();
+
+  for (Init *init : listInit->getValues()) {
+    Record *builderDef = cast<DefInit>(init)->getDef();
+    StringRef params = builderDef->getValueAsString("params");
+
+    SmallVector<StringRef, 8> args;
+    splitBuilderSignature(params, args);
+    if (args.size() < 2)
+      PrintFatalError(builderDef->getLoc(),
+                      "expected at least an OpBuilder and a state in builder "
+                      "args");
+    assert(args[0].trim().startswith("OpBuilder") &&
+           "expected first builder arg to be OpBuilder");
+    assert(args[1].trim().startswith("OperationState") &&
+           "expected second builder arg to be OperationState");
+
+    SmallVector<std::string, 8> typedArgs, argNames;
+    for (auto en : llvm::enumerate(llvm::makeArrayRef(args).drop_front(2))) {
+      ArgumentSignature signature = splitSignature(en.value());
+      std::string name = signature.name.empty()
+                             ? formatv("odsArgument{0}", en.index()).str()
+                             : signature.name.str();
+      typedArgs.push_back(
+          formatOneArg(signature.type, name, signature.initializer));
+      argNames.push_back(std::move(name));
+    }
+
+    os << formatv("  {0} {1}({2}) {{\n", opType, methodName,
+                  llvm::join(typedArgs, ", "));
+    // Emit a separator only before arguments that exist, so that builders
+    // without extra parameters yield `create<Op>(loc)` and not `(loc, )`.
+    os << "    return b.create<" << opType << ">(loc";
+    for (const std::string &name : argNames)
+      os << ", " << name;
+    os << ");\n";
+    os << "  }\n\n";
+  }
+}
+
+/// Emits the `<Dialect>Builder` class for `dialect` into the dialect's C++
+/// namespace. The class wraps an OpBuilder and a Location and has one method
+/// per custom builder of each op in `ops`.
+static void emitDialectBuilderClass(const Dialect &dialect,
+                                    ArrayRef<const Record *> ops,
+                                    raw_ostream &os) {
+  // Open each component of a nested namespace such as "mlir::foo" on its own
+  // (nested namespace definitions need C++17), and open none rather than an
+  // anonymous namespace when the dialect has no C++ namespace.
+  SmallVector<StringRef, 4> namespaces;
+  dialect.getCppNamespace().split(namespaces, "::", /*MaxSplit=*/-1,
+                                  /*KeepEmpty=*/false);
+  for (StringRef ns : namespaces)
+    os << "namespace " << ns << " {\n";
+
+  std::string className =
+      formatv("{0}Builder", dialect.getCppClassName()).str();
+  os << "class " << className << " {\n";
+  os << "public:\n";
+  // Fully qualify MLIR types: the dialect namespace need not be inside mlir.
+  os << "  " << className << "(::mlir::OpBuilder &b, ::mlir::Location loc)";
+  os << " : b(b), loc(loc) {}\n";
+
+  for (const Record *def : ops)
+    emitOneBuilderDef(*def, os);
+
+  os << "private:\n";
+  os << "  ::mlir::OpBuilder &b;\n";
+  os << "  ::mlir::Location loc;\n";
+  os << "};\n";
+  for (StringRef ns : llvm::reverse(namespaces))
+    os << "} // end namespace " << ns << "\n";
+}
+
+/// Emits one DSL-style builder class per dialect that owns any of the ops in
+/// `defs`, so that each op's methods land in its own dialect's class and
+/// namespace. Emits nothing when `defs` is empty.
+static void emitBuilderDef(ArrayRef<Record *> defs, raw_ostream &os) {
+  // Group ops by dialect, copying each Dialect out of its temporary Operator
+  // (getDialect() returns a reference into the Operator). std::map orders the
+  // dialects deterministically; ops keep their relative order within each.
+  std::map<Dialect, std::vector<const Record *>> opsByDialect;
+  for (const Record *def : defs)
+    opsByDialect[Operator(def).getDialect()].push_back(def);
+
+  for (const auto &entry : opsByDialect)
+    emitDialectBuilderClass(entry.first, entry.second, os);
+}
+
+/// Entry point of `-gen-dialect-builder`. Returns false to signal success.
+static bool emitDialectBuilder(const RecordKeeper &recordKeeper,
+                               raw_ostream &os) {
+  emitSourceFileHeader("DSL-style builder", os);
+  emitBuilderDef(getAllDerivedDefinitions(recordKeeper, "Op"), os);
+  return false;
+}
+
 static mlir::GenRegistration
     genOpDecls("gen-op-decls", "Generate op declarations",
                [](const RecordKeeper &records, raw_ostream &os) {
@@ -2257,3 +2469,7 @@
                                           raw_ostream &os) {
                                          return emitOpDefs(records, os);
                                        });
+
+static mlir::GenRegistration genDialectBuilder("gen-dialect-builder",
+                                               "Generate DSL-style builder",
+                                               emitDialectBuilder);