diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -973,15 +973,16 @@
       propertiesDeleter = [](OpaqueProperties prop) {
         delete prop.as<const T *>();
       };
-      propertiesSetter = [](OpaqueProperties new_prop,
+      propertiesSetter = [](OpaqueProperties newProp,
                             const OpaqueProperties prop) {
-        *new_prop.as<T *>() = *prop.as<const T *>();
+        *newProp.as<T *>() = *prop.as<const T *>();
       };
       propertiesId = TypeID::get<T>();
     }
     assert(propertiesId == TypeID::get<T>() && "Inconsistent properties");
     return *properties.as<T *>();
   }
+
   OpaqueProperties getRawProperties() { return properties; }
 
   // Set the properties defined on this OpState on the given operation,
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -656,7 +656,7 @@
          std::optional<pybind11::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
          DefaultingPyLocation location, const pybind11::object &ip,
-         bool inferType = false);
+         bool inferType);
 
   /// Creates an OpView suitable for this operation.
   pybind11::object createOpView();
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -15,10 +15,12 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
@@ -26,6 +28,7 @@
 #include "mlir/Parser/Parser.h"
 
 #include <cstddef>
+#include <memory>
 #include <optional>
 
 using namespace mlir;
@@ -345,25 +348,52 @@
   if (!info) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
-        << ", but the operation was not registered. Ensure that the dialect "
+        << ", but the operation was not registered; ensure that the dialect "
            "containing the operation is linked into MLIR and registered with "
            "the context";
     return failure();
   }
 
-  // Fallback to inference via an op interface.
   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
   if (!inferInterface) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
-        << ", but the operation does not support type inference. Result "
-           "types must be specified explicitly.";
+        << ", but the operation does not support type inference; result "
+           "types must be specified explicitly";
+    return failure();
+  }
+
+  DictionaryAttr attributes = state.attributes.getDictionary(context);
+  OpaqueProperties properties = state.getRawProperties();
+
+  // Ops that store inherent attributes as properties read them from the
+  // properties object during inference. When the caller only supplied a
+  // generic attribute dictionary (as the Python bindings do), materialize a
+  // temporary properties object from it for the duration of the call.
+  if (!properties && info->getOpPropertyByteSize() > 0) {
+    auto storage = std::make_unique<char[]>(info->getOpPropertyByteSize());
+    OpaqueProperties tmpProperties(storage.get());
+    // `storage` is raw bytes: construct the properties object in place before
+    // populating it, and destroy it before the storage is released.
+    info->initOpProperties(tmpProperties, OpaqueProperties(nullptr));
+    LogicalResult result = failure();
+    if (failed(info->setOpPropertiesFromAttribute(
+            state.name, tmpProperties, attributes, nullptr))) {
+      emitError(state.location)
+          << "type inference was requested for the operation " << state.name
+          << ", but its attributes could not be converted to properties";
+    } else {
+      // On failure, the diagnostic is emitted by the interface.
+      result = inferInterface->inferReturnTypes(
+          context, state.location, state.operands, attributes, tmpProperties,
+          state.regions, state.types);
+    }
+    info->destroyOpProperties(tmpProperties);
+    return result;
+  }
+
   if (succeeded(inferInterface->inferReturnTypes(
-          context, state.location, state.operands,
-          state.attributes.getDictionary(context), state.getRawProperties(),
+          context, state.location, state.operands, attributes, properties,
           state.regions, state.types)))
     return success();
 
@@ -405,8 +435,7 @@
     return {nullptr};
   }
 
-  MlirOperation result = wrap(Operation::create(cppState));
-  return result;
+  return wrap(Operation::create(cppState));
 }
 
 MlirOperation mlirOperationCreateParse(MlirContext context,
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -919,15 +919,12 @@
     ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
   Builder b(context);
   const Properties prop = adaptor.getProperties();
-  DenseIntElementsAttr shape;
-  // TODO: this is only exercised by the Python bindings codepath which does not
-  // support properties
-  shape = prop.shape ? prop.shape :
-          adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
-  if (!shape)
+  // `shape` is expected to be populated, but guard so that a malformed op
+  // yields a diagnostic instead of dereferencing a null attribute.
+  if (!prop.shape)
     return emitOptionalError(location, "missing shape attribute");
   inferredReturnTypes.assign({RankedTensorType::get(
-      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
+      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
   return success();
 }
 
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -493,9 +493,7 @@
     attributes = {{}
     regions = None
     {1}
-    super().__init__(self.build_generic(
-      attributes=attributes,{2} operands=operands,
-      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+    super().__init__(self.build_generic({2}))
 )Py";
 
 /// Template for appending a single element to the operand/result list.
@@ -919,9 +917,20 @@
   }
   functionArgs.push_back("loc=None");
   functionArgs.push_back("ip=None");
+
+  SmallVector<StringRef> initArgs;
+  initArgs.push_back("attributes=attributes");
+  if (!hasInferTypeInterface(op))
+    initArgs.push_back("results=results");
+  initArgs.push_back("operands=operands");
+  initArgs.push_back("successors=_ods_successors");
+  initArgs.push_back("regions=regions");
+  initArgs.push_back("loc=loc");
+  initArgs.push_back("ip=ip");
+
   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
                       llvm::join(builderLines, "\n    "),
-                      hasInferTypeInterface(op) ? "" : " results=results,");
+                      llvm::join(initArgs, ", "));
 }
 
 static void emitSegmentSpec(