diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt --- a/mlir/CMakeLists.txt +++ b/mlir/CMakeLists.txt @@ -100,12 +100,6 @@ # from another directory like tools add_subdirectory(tools/mlir-tblgen) -# Create an anchor target that will depend on dialect-specific op bindings. -if (MLIR_BINDINGS_PYTHON_ENABLED) - add_custom_target(MLIRBindingsPythonIncGen) - include(AddMLIRPythonExtension) -endif() - add_subdirectory(include/mlir) add_subdirectory(lib) # C API needs all dialects for registration, but should be built before tests. diff --git a/mlir/cmake/modules/AddMLIRPythonExtension.cmake b/mlir/cmake/modules/AddMLIRPythonExtension.cmake --- a/mlir/cmake/modules/AddMLIRPythonExtension.cmake +++ b/mlir/cmake/modules/AddMLIRPythonExtension.cmake @@ -132,16 +132,10 @@ endfunction() -function(add_mlir_dialect_python_bindings filename dialectname) +function(add_mlir_dialect_python_bindings tblgen_target filename dialectname) set(LLVM_TARGET_DEFINITIONS ${filename}) mlir_tablegen("${dialectname}.py" -gen-python-op-bindings -bind-dialect=${dialectname}) - if (${ARGC} GREATER 2) - set(suffix ${ARGV2}) - else() - get_filename_component(suffix ${filename} NAME_WE) - endif() - set(tblgen_target "MLIRBindingsPython${suffix}") add_public_tablegen_target(${tblgen_target}) add_custom_command( @@ -150,6 +144,5 @@ COMMAND "${CMAKE_COMMAND}" -E copy_if_different "${CMAKE_CURRENT_BINARY_DIR}/${dialectname}.py" "${PROJECT_BINARY_DIR}/python/mlir/dialects/${dialectname}.py") - add_dependencies(MLIRBindingsPythonIncGen ${tblgen_target}) endfunction() diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt @@ -7,7 +7,3 @@ add_public_tablegen_target(MLIRStandardOpsIncGen) add_mlir_doc(Ops -gen-op-doc StandardOps Dialects/) - -if (MLIR_BINDINGS_PYTHON_ENABLED) - 
add_mlir_dialect_python_bindings(Ops.td std StandardOps) -endif() diff --git a/mlir/lib/Bindings/Python/Attributes.td b/mlir/lib/Bindings/Python/Attributes.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/Attributes.td @@ -0,0 +1,34 @@ +//===-- Attributes.td - Attribute mapping for Python -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This defines the mapping between MLIR ODS attributes and the corresponding +// Python binding classes. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ATTRIBUTES +#define PYTHON_BINDINGS_ATTRIBUTES + +// A mapping between the attribute storage type and the corresponding Python +// type. There is not necessarily a 1-1 match for non-standard attributes. +class PythonAttr<string c, string p> { + string cppStorageType = c; + string pythonType = p; +} + +// Mappings between supported standard attributes and Python types. 
+def : PythonAttr<"::mlir::Attribute", "_ir.Attribute">; +def : PythonAttr<"::mlir::BoolAttr", "_ir.BoolAttr">; +def : PythonAttr<"::mlir::IntegerAttr", "_ir.IntegerAttr">; +def : PythonAttr<"::mlir::FloatAttr", "_ir.FloatAttr">; +def : PythonAttr<"::mlir::StringAttr", "_ir.StringAttr">; +def : PythonAttr<"::mlir::DenseElementsAttr", "_ir.DenseElementsAttr">; +def : PythonAttr<"::mlir::DenseIntElementsAttr", "_ir.DenseIntElementsAttr">; +def : PythonAttr<"::mlir::DenseFPElementsAttr", "_ir.DenseFPElementsAttr">; + +#endif diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -1,5 +1,15 @@ include(AddMLIRPythonExtension) add_custom_target(MLIRBindingsPythonExtension) + +################################################################################ +# Generate dialect-specific bindings. +################################################################################ + +add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps + StandardOps.td + std) +add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonStandardOps) + ################################################################################ # Copy python source tree. 
################################################################################ @@ -19,8 +29,6 @@ ) add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources) -add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonIncGen) - foreach(PY_SRC_FILE ${PY_SRC_FILES}) set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") add_custom_command( diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -1310,8 +1310,14 @@ return mlirOperationGetNumAttributes(operation->get()); } + bool dunderContains(const std::string &name) { + return !mlirAttributeIsNull( + mlirOperationGetAttributeByName(operation->get(), name.c_str())); + } + static void bind(py::module &m) { py::class_(m, "OpAttributeMap") + .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed); @@ -1747,6 +1753,24 @@ } }; +/// Unit Attribute subclass. Unit attributes don't have values. 
+class PyUnitAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; + static constexpr const char *pyClassName = "UnitAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnitAttribute(context->getRef(), + mlirUnitAttrGet(context->get())); + }, + py::arg("context") = py::none(), "Create a Unit attribute."); + } +}; + } // namespace //------------------------------------------------------------------------------ @@ -2852,6 +2876,7 @@ PyDenseElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); + PyUnitAttribute::bind(m); //---------------------------------------------------------------------------- // Mapping of PyType. diff --git a/mlir/lib/Bindings/Python/StandardOps.td b/mlir/lib/Bindings/Python/StandardOps.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/StandardOps.td @@ -0,0 +1,20 @@ +//===-- StandardOps.td - Entry point for StandardOps bind --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the main file from which the Python bindings for the Standard +// dialect are generated. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_STANDARD_OPS +#define PYTHON_BINDINGS_STANDARD_OPS + +include "mlir/Dialect/StandardOps/IR/Ops.td" +include "Attributes.td" + +#endif diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -147,7 +147,7 @@ StringRef Operator::getArgName(int index) const { DagInit *argumentValues = def.getValueAsDag("arguments"); - return argumentValues->getArgName(index)->getValue(); + return argumentValues->getArgNameStr(index); } auto Operator::getArgDecorators(int index) const -> var_decorator_range { diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -1,6 +1,7 @@ -// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s +// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include -I %S/../../lib/Bindings/Python %s | FileCheck %s include "mlir/IR/OpBase.td" +include "Attributes.td" // CHECK: @_cext.register_dialect // CHECK: class _Dialect(_ir.Dialect): @@ -105,6 +106,75 @@ Optional:$variadic2); } + +// CHECK: @_cext.register_operation(_Dialect) +// CHECK: class AttributedOp(_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.attributed_op" +def AttributedOp : TestOp<"attributed_op"> { + // CHECK: def __init__(self, i32attr, optionalF32Attr, unitAttr, in_, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: attributes["i32attr"] = i32attr + // CHECK: if optionalF32Attr is not None: attributes["optionalF32Attr"] = optionalF32Attr + // CHECK: if bool(unitAttr): attributes["unitAttr"] = _ir.UnitAttr.get( + // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: 
attributes["in"] = in_ + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.attributed_op", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + + // CHECK: @property + // CHECK: def i32attr(self): + // CHECK: return _ir.IntegerAttr(self.operation.attributes["i32attr"]) + + // CHECK: @property + // CHECK: def optionalF32Attr(self): + // CHECK: if "optionalF32Attr" not in self.operation.attributes: + // CHECK: return None + // CHECK: return _ir.FloatAttr(self.operation.attributes["optionalF32Attr"]) + + // CHECK: @property + // CHECK: def unitAttr(self): + // CHECK: return "unitAttr" in self.operation.attributes + + // CHECK: @property + // CHECK: def in_(self): + // CHECK: return _ir.IntegerAttr(self.operation.attributes["in"]) + let arguments = (ins I32Attr:$i32attr, OptionalAttr<F32Attr>:$optionalF32Attr, + UnitAttr:$unitAttr, I32Attr:$in); +} + +// CHECK: @_cext.register_operation(_Dialect) +// CHECK: class AttributedOpWithOperands(_ir.OpView): +// CHECK-LABEL: OPERATION_NAME = "test.attributed_op_with_operands" +def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { + // CHECK: def __init__(self, _gen_arg_0, in_, _gen_arg_2, is_, loc=None, ip=None): + // CHECK: operands = [] + // CHECK: results = [] + // CHECK: attributes = {} + // CHECK: operands.append(_gen_arg_0) + // CHECK: operands.append(_gen_arg_2) + // CHECK: if bool(in_): attributes["in"] = _ir.UnitAttr.get( + // CHECK: _ir.Location.current.context if loc is None else loc.context) + // CHECK: if is_ is not None: attributes["is"] = is_ + // CHECK: super().__init__(_ir.Operation.create( + // CHECK: "test.attributed_op_with_operands", attributes=attributes, operands=operands, results=results, + // CHECK: loc=loc, ip=ip)) + + // CHECK: @property + // CHECK: def in_(self): + // CHECK: return "in" in self.operation.attributes + + // CHECK: @property + // CHECK: def is_(self): + // CHECK: if "is" not in self.operation.attributes: + // CHECK: return 
None + // CHECK: return _ir.FloatAttr(self.operation.attributes["is"]) + let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is); +} + + // CHECK: @_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.empty" diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -145,6 +145,39 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate = R"Py([0] if len({0}_range) > 0 else None)Py"; +/// Template for an operation attribute getter: +/// {0} is the name of the attribute sanitized for Python; +/// {1} is the Python type of the attribute; +/// {2} is the original name of the attribute. +constexpr const char *attributeGetterTemplate = R"Py( + @property + def {0}(self): + return {1}(self.operation.attributes["{2}"]) +)Py"; + +/// Template for an optional operation attribute getter: +/// {0} is the name of the attribute sanitized for Python; +/// {1} is the Python type of the attribute; +/// {2} is the original name of the attribute. +constexpr const char *optionalAttributeGetterTemplate = R"Py( + @property + def {0}(self): + if "{2}" not in self.operation.attributes: + return None + return {1}(self.operation.attributes["{2}"]) +)Py"; + +/// Template for accessing a unit operation attribute, returns True if the +/// unit attribute is present, False otherwise (unit attributes have meaning +/// by mere presence): +/// {0} is the name of the attribute sanitized for Python, +/// {1} is the original name of the attribute. 
+constexpr const char *unitAttributeGetterTemplate = R"Py( + @property + def {0}(self): + return "{1}" in self.operation.attributes +)Py"; + static llvm::cl::OptionCategory clOpPythonBindingCat("Options for -gen-python-op-bindings"); @@ -153,6 +186,8 @@ llvm::cl::desc("The dialect to run the generator for"), llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat)); +using AttributeClasses = DenseMap<StringRef, StringRef>; + /// Checks whether `str` is a Python keyword. static bool isPythonKeyword(StringRef str) { static llvm::StringSet<> keywords( @@ -285,7 +320,7 @@ return op.getResult(i); } -/// Emits accessor to Op operands. +/// Emits accessors to Op operands. static void emitOperandAccessors(const Operator &op, raw_ostream &os) { auto getNumVariadic = [](const Operator &oper) { return oper.getNumVariableLengthOperands(); @@ -294,7 +329,7 @@ getOperand); } -/// Emits access or Op results. +/// Emits accessors to Op results. static void emitResultAccessors(const Operator &op, raw_ostream &os) { auto getNumVariadic = [](const Operator &oper) { return oper.getNumVariableLengthResults(); @@ -303,6 +338,39 @@ getResult); } +/// Emits accessors to Op attributes. +static void emitAttributeAccessors(const Operator &op, + const AttributeClasses &attributeClasses, + raw_ostream &os) { + for (const auto &namedAttr : op.getAttributes()) { + // Skip "derived" attributes because they are just C++ functions that we + // don't currently expose. + if (namedAttr.attr.isDerivedAttr()) + continue; + + if (namedAttr.name.empty()) + continue; + + // Unit attributes are handled specially. + if (namedAttr.attr.getStorageType().trim().equals("::mlir::UnitAttr")) { + os << llvm::formatv(unitAttributeGetterTemplate, + sanitizeName(namedAttr.name), namedAttr.name); + continue; + } + + // Other kinds of attributes need a mapping to a Python type. + if (!attributeClasses.count(namedAttr.attr.getStorageType().trim())) + continue; + + os << llvm::formatv( + namedAttr.attr.isOptional() ? 
optionalAttributeGetterTemplate + : attributeGetterTemplate, + sanitizeName(namedAttr.name), + attributeClasses.lookup(namedAttr.attr.getStorageType().trim()), + namedAttr.name); + } +} + /// Template for the default auto-generated builder. /// {0} is the operation name; /// {1} is a comma-separated list of builder arguments, including the trailing @@ -362,14 +430,82 @@ constexpr const char *variadicSegmentTemplate = "{0}_segment_sizes.append(len({1}))"; -/// Populates `builderArgs` with the list of `__init__` arguments that -/// correspond to either operands or results of `op`, and `builderLines` with -/// additional lines that are required in the builder. `kind` must be either -/// "operand" or "result". `unnamedTemplate` is used to generate names for -/// operands or results that don't have the name in ODS. +/// Template for setting an attribute in the operation builder. +/// {0} is the attribute name; +/// {1} is the builder argument name. +constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; + +/// Template for setting an optional attribute in the operation builder. +/// {0} is the attribute name; +/// {1} is the builder argument name. +constexpr const char *initOptionalAttributeTemplate = + R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; + +constexpr const char *initUnitAttributeTemplate = + R"Py(if bool({1}): attributes["{0}"] = _ir.UnitAttr.get( + _ir.Location.current.context if loc is None else loc.context))Py"; + +/// Populates `builderArgs` with the Python-compatible names of builder function +/// arguments, first the results, then the intermixed attributes and operands in +/// the same order as they appear in the `arguments` field of the op definition. +/// Additionally, `operandNames` is populated with names of operands in their +/// order of appearance. 
+static void +populateBuilderArgs(const Operator &op, + llvm::SmallVectorImpl &builderArgs, + llvm::SmallVectorImpl &operandNames) { + for (int i = 0, e = op.getNumResults(); i < e; ++i) { + std::string name = op.getResultName(i).str(); + if (name.empty()) + name = llvm::formatv("_gen_res_{0}", i); + name = sanitizeName(name); + builderArgs.push_back(name); + } + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { + std::string name = op.getArgName(i).str(); + if (name.empty()) + name = llvm::formatv("_gen_arg_{0}", i); + name = sanitizeName(name); + builderArgs.push_back(name); + if (!op.getArg(i).is()) + operandNames.push_back(name); + } +} + +/// Populates `builderLines` with additional lines that are required in the +/// builder to set up operation attributes. `argNames` is expected to contain +/// the names of builder arguments that correspond to op arguments, i.e. to the +/// operands and attributes in the same order as they appear in the `arguments` +/// field. +static void +populateBuilderLinesAttr(const Operator &op, + llvm::ArrayRef argNames, + llvm::SmallVectorImpl &builderLines) { + for (int i = 0, e = op.getNumArgs(); i < e; ++i) { + Argument arg = op.getArg(i); + auto *attribute = arg.dyn_cast(); + if (!attribute) + continue; + + // Unit attributes are handled specially. + if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { + builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, + attribute->name, argNames[i])); + continue; + } + + builderLines.push_back(llvm::formatv(attribute->attr.isOptional() + ? initOptionalAttributeTemplate + : initAttributeTemplate, + attribute->name, argNames[i])); + } +} + +/// Populates `builderLines` with additional lines that are required in the +/// builder. `kind` must be either "operand" or "result". `names` contains the +/// names of init arguments that correspond to the elements. 
static void populateBuilderLines( - const Operator &op, const char *kind, const char *unnamedTemplate, - llvm::SmallVectorImpl &builderArgs, + const Operator &op, const char *kind, llvm::ArrayRef names, llvm::SmallVectorImpl &builderLines, llvm::function_ref getNumElements, llvm::function_ref @@ -383,11 +519,7 @@ // For each element, find or generate a name. for (int i = 0, e = getNumElements(op); i < e; ++i) { const NamedTypeConstraint &element = getElement(op, i); - std::string name = element.name.str(); - if (name.empty()) - name = llvm::formatv(unnamedTemplate, i).str(); - name = sanitizeName(name); - builderArgs.push_back(name); + std::string name = names[i]; // Choose the formatting string based on the element kind. llvm::StringRef formatString, segmentFormatString; @@ -417,21 +549,25 @@ /// Emits a default builder constructing an operation from the list of its /// result types, followed by a list of its operands. static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) { - // TODO: support attribute types. - if (op.getNumNativeAttributes() != 0) - return; - // If we are asked to skip default builders, comply. 
if (op.skipDefaultBuilders()) return; llvm::SmallVector builderArgs; llvm::SmallVector builderLines; - builderArgs.reserve(op.getNumOperands() + op.getNumResults()); - populateBuilderLines(op, "result", "_gen_res_{0}", builderArgs, builderLines, - getNumResults, getResult); - populateBuilderLines(op, "operand", "_gen_arg_{0}", builderArgs, builderLines, + llvm::SmallVector operandArgNames; + builderArgs.reserve(op.getNumOperands() + op.getNumResults() + + op.getNumNativeAttributes()); + populateBuilderArgs(op, builderArgs, operandArgNames); + populateBuilderLines( + op, "result", + llvm::makeArrayRef(builderArgs).take_front(op.getNumResults()), + builderLines, getNumResults, getResult); + populateBuilderLines(op, "operand", operandArgNames, builderLines, getNumOperands, getOperand); + populateBuilderLinesAttr( + op, llvm::makeArrayRef(builderArgs).drop_front(op.getNumResults()), + builderLines); builderArgs.push_back("loc=None"); builderArgs.push_back("ip=None"); @@ -440,12 +576,24 @@ llvm::join(builderLines, "\n ")); } +static void constructAttributeMapping(const llvm::RecordKeeper &records, + AttributeClasses &attributeClasses) { + for (const llvm::Record *rec : + records.getAllDerivedDefinitions("PythonAttr")) { + attributeClasses.try_emplace(rec->getValueAsString("cppStorageType").trim(), + rec->getValueAsString("pythonType").trim()); + } +} + /// Emits bindings for a specific Op to the given output stream. 
-static void emitOpBindings(const Operator &op, raw_ostream &os) { +static void emitOpBindings(const Operator &op, + const AttributeClasses &attributeClasses, + raw_ostream &os) { os << llvm::formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); emitDefaultOpBuilder(op, os); emitOperandAccessors(op, os); + emitAttributeAccessors(op, attributeClasses, os); emitResultAccessors(op, os); } @@ -456,12 +604,15 @@ if (clDialectName.empty()) llvm::PrintFatalError("dialect name not provided"); + AttributeClasses attributeClasses; + constructAttributeMapping(records, attributeClasses); + os << fileHeader; os << llvm::formatv(dialectClassTemplate, clDialectName.getValue()); for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) { Operator op(rec); if (op.getDialectName() == clDialectName.getValue()) - emitOpBindings(op, os); + emitOpBindings(op, attributeClasses, os); } return false; }