diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -15,10 +15,12 @@
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Utils.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
 #include "mlir/IR/Verifier.h"
@@ -26,6 +28,7 @@
 #include "mlir/Parser/Parser.h"
 
 #include <cstddef>
+#include <memory>
 #include <optional>
 
 using namespace mlir;
@@ -345,25 +348,52 @@
   if (!info) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
-        << ", but the operation was not registered. Ensure that the dialect "
+        << ", but the operation was not registered; ensure that the dialect "
            "containing the operation is linked into MLIR and registered with "
            "the context";
     return failure();
   }
 
-  // Fallback to inference via an op interface.
   auto *inferInterface = info->getInterface<InferTypeOpInterface>();
   if (!inferInterface) {
     emitError(state.location)
         << "type inference was requested for the operation " << state.name
-        << ", but the operation does not support type inference. Result "
-           "types must be specified explicitly.";
+        << ", but the operation does not support type inference; result "
+           "types must be specified explicitly";
     return failure();
   }
 
+  DictionaryAttr attributes = state.attributes.getDictionary(context);
+  OpaqueProperties properties = state.getRawProperties();
+
+  // The C API state may carry inherent attributes only in the attribute
+  // dictionary. If the operation stores them as properties and none were
+  // provided, build a temporary properties object from the dictionary so that
+  // type inference can read them. The deleter runs the properties destructor
+  // before freeing the raw bytes, on every exit path.
+  auto destroyProperties = [&info](char *storage) {
+    info->destroyOpProperties(OpaqueProperties(storage));
+    delete[] storage;
+  };
+  std::unique_ptr<char[], decltype(destroyProperties)> propertiesStorage(
+      nullptr, destroyProperties);
+  if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) {
+    propertiesStorage.reset(new char[info->getOpPropertyByteSize()]);
+    properties = OpaqueProperties(propertiesStorage.get());
+    // The storage is raw bytes: construct the properties object before the
+    // conversion below assigns into it.
+    info->initOpProperties(properties, /*init=*/OpaqueProperties(nullptr));
+    if (failed(info->setOpPropertiesFromAttribute(state.name, properties,
+                                                  attributes, nullptr))) {
+      emitError(state.location)
+          << "type inference was requested for the operation " << state.name
+          << ", but its attributes could not be converted to properties";
+      return failure();
+    }
+  }
+
   if (succeeded(inferInterface->inferReturnTypes(
-          context, state.location, state.operands,
-          state.attributes.getDictionary(context), state.getRawProperties(),
+          context, state.location, state.operands, attributes, properties,
           state.regions, state.types)))
     return success();
 
@@ -405,8 +435,7 @@
     return {nullptr};
   }
 
-  MlirOperation result = wrap(Operation::create(cppState));
-  return result;
+  return wrap(Operation::create(cppState));
 }
 
 MlirOperation mlirOperationCreateParse(MlirContext context,
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -919,15 +919,14 @@
     ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
   Builder b(context);
   const Properties prop = adaptor.getProperties();
-  DenseIntElementsAttr shape;
-  // TODO: this is only exercised by the Python bindings codepath which does not
-  // support properties
-  shape = prop.shape ? prop.shape :
-                       adaptor.getAttributes().getAs<DenseIntElementsAttr>("shape");
-  if (!shape)
+  // The shape is read from the properties only; the C API builder converts
+  // inherent attributes into properties before requesting type inference.
+  // It can still be absent when the op is built without a `shape` attribute,
+  // so diagnose that instead of dereferencing a null attribute.
+  if (!prop.shape)
     return emitOptionalError(location, "missing shape attribute");
   inferredReturnTypes.assign({RankedTensorType::get(
-      {static_cast<int64_t>(shape.size())}, b.getIndexType())});
+      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
   return success();
 }
 