diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -0,0 +1,262 @@ +//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python +// binding classes wrapping a generic operation API. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Operator.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +constexpr const char *fileHeader = R"Py( +# Autogenerated by mlir-tblgen; don't manually edit. + +from . import _cext +_ir = _cext.ir +)Py"; + +constexpr const char *segmentedListAccessorTemplate = R"Py( +def _get_segmented_{0}_range(operation, idx): + raw_segments = operation.attributes["{0}_segment_sizes"] + segments = _ir.DenseIntElementsAttr(raw_segments) + start = sum(segments[i] for i in range(idx)) + end = start + segments[idx] + return operation.{0}s[start:end] + +)Py"; + +constexpr const char *equallySizedAccessorTemplate = R"Py( +def _get_equally_sized_{0}_range(operation, n_variadic, n_preceding_simple, + n_preceding_variadic) + total_variadic_length = len(operation.{0}s) - n_variadic + 1 + if total_variadic_length % n_variadic != 0 + raise RuntimeError("unexpected mismatch in number of {0}s " + "for equally-sized variadic groups") + elements_per_group = total_variadic_length // n_variadic + start = n_preceding_simple + n_preceding_variadic * elements_per_group + return start, elements_per_group +)Py"; + +constexpr const char *dialectClassTemplate = R"Py( +@_cext.register_dialect +class _Dialect(_ir.Dialect): + DIALECT_NAMESPACE = "{0}" + pass + +)Py"; + +constexpr const char *opClassTemplate = R"Py( +@_cext.register_operation(_Dialect) +class {0}(_ir.OpView): + OPERATION_NAME = "{1}" +)Py"; + +constexpr const char *opSingleTemplate = R"Py( + @property + def {0}(self): + return self.operation.{1}s[{2}] +)Py"; + +constexpr const char *opOneOptionalTemplate = R"Py( + @property + def {0}(self); + return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2} + else None +)Py"; + +constexpr const char *opOneVariadicTemplate = R"Py( + @property + def {0}(self): + variadic_group_length = len(self.operation.{1}s) - {2} + 1 + return self.operation.{1}s[{3}:{3} + variadic_group_length] +)Py"; + +constexpr const char *opVariadicEqualPrefixTemplate = R"Py( + @property + def {0}(self): + start, pg = _get_equally_sized_{1}_range(operation, {2}, {3}, {4}))Py"; + +constexpr const char *opVariadicEqualSimpleTemplate = R"Py( + return self.operation.{0}s[start] +)Py"; + +constexpr const char *opVariadicEqualVariadicTemplate = R"Py( + return self.operation.{0}s[start:start + pg] +)Py"; + +constexpr const char *opVariadicSegmentTemplate = R"Py( + @property + def {0}(self): + return _get_segmented_{1}_range(self.operation, {2}) +)Py"; + +static llvm::cl::OptionCategory + clOpPythonBindingCat("Options for -gen-python-op-bindings"); + +static llvm::cl::opt + clDialectName("bind-dialect", + llvm::cl::desc("The dialect to run the generator for"), + llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); + +static bool isPythonKeyword(StringRef str) { + static llvm::StringSet<> keywords( + {"and", "as", "assert", "break", "class", "continue", + "def", "del", "elif", "else", "except", "finally", + "for", "from", "global", "if", "import", "in", + "is", "lambda", "nonlocal", "not", "or", "pass", + "raise", "return", "try", "while", "with", "yield"}); + return keywords.contains(str); +}; + +static std::string sanitizeName(StringRef name) { + if (isPythonKeyword(name)) + return (name + "_").str(); + return name.str(); +} + +static void emitElementAccessors( + const Operator &op, raw_ostream &os, const char *kind, + llvm::function_ref getNumVariadic, + llvm::function_ref getNumElements, + llvm::function_ref + getElement) { + std::string sameSizeTrait = + llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size", + llvm::StringRef(kind).take_front().upper(), + llvm::StringRef(kind).drop_front()); + std::string attrSizedTrait = + llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments", + llvm::StringRef(kind).take_front().upper(), + llvm::StringRef(kind).drop_front()); + + unsigned numVariadic = getNumVariadic(op); + + // If there is only one variadic element group, its size can be inferred from + // the total number of elements. If there are none, the generation is + // straightforward. + if (numVariadic <= 1) { + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + if (element.name.empty()) + continue; + if (element.isVariableLength()) { + os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate + : opOneVariadicTemplate, + sanitizeName(element.name), kind, + getNumElements(op), i); + } else { + os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind, + i); + } + } + return; + } + + if (op.getTrait(sameSizeTrait)) { + int numPrecedingSimple = 0; + int numPrecedingVariadic = 0; + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + if (!element.name.empty()) { + os << llvm::formatv(opVariadicEqualPrefixTemplate, + sanitizeName(element.name), kind, numVariadic, + numPrecedingSimple, numPrecedingVariadic); + os << llvm::formatv(element.isVariableLength() + ? opVariadicEqualVariadicTemplate + : opVariadicEqualSimpleTemplate, + kind); + } + if (element.isVariableLength()) + ++numPrecedingVariadic; + else + ++numPrecedingSimple; + } + return; + } + + if (op.getTrait(attrSizedTrait)) { + for (int i = 0, e = getNumElements(op); i < e; ++i) { + const NamedTypeConstraint &element = getElement(op, i); + if (element.name.empty()) + continue; + os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name), + kind, i); + } + return; + } + + llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure"); +} + +static void emitOperandAccessors(const Operator &op, raw_ostream &os) { + auto getNumVariadic = [](const Operator &oper) { + return oper.getNumVariableLengthOperands(); + }; + auto getNumElements = [](const Operator &oper) { + return oper.getNumOperands(); + }; + auto getElement = [](const Operator &oper, + int i) -> const NamedTypeConstraint & { + return oper.getOperand(i); + }; + emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements, + getElement); +} + +static void emitResultAccessors(const Operator &op, raw_ostream &os) { + auto getNumVariadic = [](const Operator &oper) { + return oper.getNumVariableLengthResults(); + }; + auto getNumElements = [](const Operator &oper) { + return oper.getNumResults(); + }; + auto getElement = [](const Operator &oper, int i) { + return oper.getResult(i); + }; + emitElementAccessors(op, os, "result", getNumVariadic, getNumElements, + getElement); +} + +static bool emitOpBindings(const Operator &op, raw_ostream &os) { + os << llvm::formatv(opClassTemplate, op.getCppClassName(), + op.getOperationName()); + emitOperandAccessors(op, os); + emitResultAccessors(op, os); + return false; +} + +static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) { + if (clDialectName.empty()) + llvm::PrintFatalError("dialect name not provided"); + + os << fileHeader; + os << llvm::formatv(segmentedListAccessorTemplate, "operand"); + os << llvm::formatv(segmentedListAccessorTemplate, "result"); + os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); + for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { + Operator op(rec); + if (op.getDialectName() != clDialectName.getValue()) + continue; + if (emitOpBindings(op, os)) + return true; + } + return false; +} + +static GenRegistration + genPythonBindings("gen-python-op-bindings", + "Generate Python bindings for MLIR Ops", &emitAllOps);