Changeset View
Changeset View
Standalone View
Standalone View
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// | //===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===// | ||||
// | // | ||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
// See https://llvm.org/LICENSE.txt for license information. | // See https://llvm.org/LICENSE.txt for license information. | ||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// | // | ||||
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python | // OpPythonBindingGen uses ODS specification of MLIR ops to generate Python | ||||
// binding classes wrapping a generic operation API. | // binding classes wrapping a generic operation API. | ||||
// | // | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
#include "mlir/Support/LogicalResult.h" | |||||
#include "mlir/TableGen/GenInfo.h" | #include "mlir/TableGen/GenInfo.h" | ||||
#include "mlir/TableGen/Operator.h" | #include "mlir/TableGen/Operator.h" | ||||
#include "llvm/ADT/StringSet.h" | #include "llvm/ADT/StringSet.h" | ||||
#include "llvm/Support/CommandLine.h" | #include "llvm/Support/CommandLine.h" | ||||
#include "llvm/Support/FormatVariadic.h" | #include "llvm/Support/FormatVariadic.h" | ||||
#include "llvm/TableGen/Error.h" | #include "llvm/TableGen/Error.h" | ||||
#include "llvm/TableGen/Record.h" | #include "llvm/TableGen/Record.h" | ||||
▲ Show 20 Lines • Show All 515 Lines • ▼ Show 20 Lines | |||||
constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; | constexpr const char *initAttributeTemplate = R"Py(attributes["{0}"] = {1})Py"; | ||||
/// Template for setting an optional attribute in the operation builder. | /// Template for setting an optional attribute in the operation builder. | ||||
/// {0} is the attribute name; | /// {0} is the attribute name; | ||||
/// {1} is the builder argument name. | /// {1} is the builder argument name. | ||||
constexpr const char *initOptionalAttributeTemplate = | constexpr const char *initOptionalAttributeTemplate = | ||||
R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; | R"Py(if {1} is not None: attributes["{0}"] = {1})Py"; | ||||
/// Template for setting an attribute with a default value in the operation | |||||
/// builder. | |||||
/// {0} is the attribute name; | |||||
/// {1} is the builder argument name; | |||||
/// {2} is the default value. | |||||
constexpr const char *initDefaultValuedAttributeTemplate = | |||||
R"Py(attributes["{0}"] = {1} if {1} is not None else {2})Py"; | |||||
/// Template for asserting that an attribute value was provided when calling a | |||||
/// builder. | |||||
/// {0} is the attribute name; | |||||
/// {1} is the builder argument name. | |||||
constexpr const char *assertAttributeValueSpecified = | |||||
R"Py(assert {1} is not None, "attribute {0} must be specified")Py"; | |||||
constexpr const char *initUnitAttributeTemplate = | constexpr const char *initUnitAttributeTemplate = | ||||
R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( | R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get( | ||||
_ods_get_default_loc_context(loc)))Py"; | _ods_get_default_loc_context(loc)))Py"; | ||||
/// Template to initialize the successors list in the builder if there are any | /// Template to initialize the successors list in the builder if there are any | ||||
/// successors. | /// successors. | ||||
/// {0} is the value to initialize the successors list to. | /// {0} is the value to initialize the successors list to. | ||||
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; | constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py"; | ||||
▲ Show 20 Lines • Show All 89 Lines • ▼ Show 20 Lines | for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) { | ||||
if (name.empty()) | if (name.empty()) | ||||
name = llvm::formatv("_gen_successor_{0}", i); | name = llvm::formatv("_gen_successor_{0}", i); | ||||
name = sanitizeName(name); | name = sanitizeName(name); | ||||
builderArgs.push_back(name); | builderArgs.push_back(name); | ||||
successorArgNames.push_back(name); | successorArgNames.push_back(name); | ||||
} | } | ||||
} | } | ||||
/// Generates Python code for the default value of the given attribute. | |||||
static FailureOr<std::string> getAttributeDefaultValue(Attribute attr) { | |||||
assert(attr.hasDefaultValue() && "expected attribute with default value"); | |||||
StringRef storageType = attr.getStorageType().trim(); | |||||
StringRef defaultValCpp = attr.getDefaultValue().trim(); | |||||
// A list of commonly used attribute types and default values for which | |||||
// we can generate Python code. Extend as needed. | |||||
if (storageType.equals("::mlir::ArrayAttr") && defaultValCpp.equals("{}")) | |||||
return std::string("_ods_ir.ArrayAttr.get([])"); | |||||
// No match: Cannot generate Python code. | |||||
return failure(); | |||||
} | |||||
/// Populates `builderLines` with additional lines that are required in the | /// Populates `builderLines` with additional lines that are required in the | ||||
/// builder to set up operation attributes. `argNames` is expected to contain | /// builder to set up operation attributes. `argNames` is expected to contain | ||||
/// the names of builder arguments that correspond to op arguments, i.e. to the | /// the names of builder arguments that correspond to op arguments, i.e. to the | ||||
/// operands and attributes in the same order as they appear in the `arguments` | /// operands and attributes in the same order as they appear in the `arguments` | ||||
/// field. | /// field. | ||||
static void | static void | ||||
populateBuilderLinesAttr(const Operator &op, | populateBuilderLinesAttr(const Operator &op, | ||||
llvm::ArrayRef<std::string> argNames, | llvm::ArrayRef<std::string> argNames, | ||||
llvm::SmallVectorImpl<std::string> &builderLines) { | llvm::SmallVectorImpl<std::string> &builderLines) { | ||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) { | for (int i = 0, e = op.getNumArgs(); i < e; ++i) { | ||||
Argument arg = op.getArg(i); | Argument arg = op.getArg(i); | ||||
auto *attribute = arg.dyn_cast<NamedAttribute *>(); | auto *attribute = arg.dyn_cast<NamedAttribute *>(); | ||||
if (!attribute) | if (!attribute) | ||||
continue; | continue; | ||||
// Unit attributes are handled specially. | // Unit attributes are handled specially. | ||||
if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { | if (attribute->attr.getStorageType().trim().equals("::mlir::UnitAttr")) { | ||||
builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, | builderLines.push_back(llvm::formatv(initUnitAttributeTemplate, | ||||
attribute->name, argNames[i])); | attribute->name, argNames[i])); | ||||
continue; | continue; | ||||
} | } | ||||
// Attributes with default value are handled specially. | builderLines.push_back(llvm::formatv( | ||||
if (attribute->attr.hasDefaultValue()) { | (attribute->attr.isOptional() || attribute->attr.hasDefaultValue()) | ||||
// In case we cannot generate Python code for the default value, the | |||||
// attribute must be specified by the user. | |||||
FailureOr<std::string> defaultValPy = | |||||
getAttributeDefaultValue(attribute->attr); | |||||
if (succeeded(defaultValPy)) { | |||||
builderLines.push_back(llvm::formatv(initDefaultValuedAttributeTemplate, | |||||
attribute->name, argNames[i], | |||||
*defaultValPy)); | |||||
} else { | |||||
builderLines.push_back(llvm::formatv(assertAttributeValueSpecified, | |||||
attribute->name, argNames[i])); | |||||
builderLines.push_back( | |||||
llvm::formatv(initAttributeTemplate, attribute->name, argNames[i])); | |||||
} | |||||
continue; | |||||
} | |||||
builderLines.push_back(llvm::formatv(attribute->attr.isOptional() | |||||
? initOptionalAttributeTemplate | ? initOptionalAttributeTemplate | ||||
: initAttributeTemplate, | : initAttributeTemplate, | ||||
attribute->name, argNames[i])); | attribute->name, argNames[i])); | ||||
} | } | ||||
} | } | ||||
/// Populates `builderLines` with additional lines that are required in the | /// Populates `builderLines` with additional lines that are required in the | ||||
/// builder to set up successors. successorArgNames is expected to correspond | /// builder to set up successors. successorArgNames is expected to correspond | ||||
/// to the Python argument name for each successor on the op. | /// to the Python argument name for each successor on the op. | ||||
static void populateBuilderLinesSuccessors( | static void populateBuilderLinesSuccessors( | ||||
const Operator &op, llvm::ArrayRef<std::string> successorArgNames, | const Operator &op, llvm::ArrayRef<std::string> successorArgNames, | ||||
▲ Show 20 Lines • Show All 360 Lines • Show Last 20 Lines |