diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -225,6 +225,7 @@
   MlirBlock *successors;
   intptr_t nAttributes;
   MlirNamedAttribute *attributes;
+  bool enableResultTypeInference;
 };
 typedef struct MlirOperationState MlirOperationState;
 
@@ -249,6 +250,14 @@
 mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n,
                                 MlirNamedAttribute const *attributes);
 
+/// Enables result type inference for the operation under construction. If
+/// enabled, then the caller must not have called
+/// mlirOperationStateAddResults(). Note that if enabled, the
+/// mlirOperationCreate() call is failable: it will return a null operation
+/// on inference failure and will emit diagnostics.
+MLIR_CAPI_EXPORTED void
+mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
+
 //===----------------------------------------------------------------------===//
 // Op Printing flags API.
 // While many of these are simple settings that could be represented in a
@@ -293,8 +302,14 @@
 //===----------------------------------------------------------------------===//
 
 /// Creates an operation and transfers ownership to the caller.
-MLIR_CAPI_EXPORTED MlirOperation
-mlirOperationCreate(const MlirOperationState *state);
+/// Note that caller-owned child objects are transferred in this call and must
+/// not be further used. Particularly, this applies to any regions added to
+/// the state (the implementation may invalidate any such pointers).
+///
+/// This call can fail under the following conditions, in which case, it will
+/// return a null operation and emit diagnostics:
+///   - Result type inference is enabled and cannot be performed.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
 
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser.h"
 
 using namespace mlir;
@@ -188,6 +189,7 @@
   state.successors = nullptr;
   state.nAttributes = 0;
   state.attributes = nullptr;
+  state.enableResultTypeInference = false;
   return state;
 }
 
@@ -219,11 +221,50 @@
   APPEND_ELEMS(MlirNamedAttribute, nAttributes, attributes);
 }
 
+void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) {
+  state->enableResultTypeInference = true;
+}
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//
 
-MlirOperation mlirOperationCreate(const MlirOperationState *state) {
+/// Infers the result types of the operation described by `state` and stores
+/// them in `state.types`. Returns failure, after a diagnostic was emitted, if
+/// the op is unregistered, lacks the interface, or inference itself fails.
+static LogicalResult inferOperationTypes(OperationState &state) {
+  MLIRContext *context = state.getContext();
+  const AbstractOperation *abstractOp =
+      AbstractOperation::lookup(state.name.getStringRef(), context);
+  if (!abstractOp) {
+    emitError(state.location)
+        << "type inference was requested for the operation " << state.name
+        << ", but the operation was not registered. Ensure that the dialect "
+           "containing the operation is linked into MLIR and registered with "
+           "the context";
+    return failure();
+  }
+
+  // Infer the result types via the op's InferTypeOpInterface.
+  auto *inferInterface = abstractOp->getInterface<InferTypeOpInterface>();
+  if (!inferInterface) {
+    emitError(state.location)
+        << "type inference was requested for the operation " << state.name
+        << ", but the operation does not support type inference. Result "
+           "types must be specified explicitly.";
+    return failure();
+  }
+
+  if (succeeded(inferInterface->inferReturnTypes(
+          context, state.location, state.operands,
+          state.attributes.getDictionary(context), state.regions, state.types)))
+    return success();
+
+  // Diagnostic emitted by interface.
+  return failure();
+}
+
+MlirOperation mlirOperationCreate(MlirOperationState *state) {
   assert(state);
   OperationState cppState(unwrap(state->location), unwrap(state->name));
   SmallVector<Type, 4> resultStorage;
@@ -243,12 +284,21 @@
   for (intptr_t i = 0; i < state->nRegions; ++i)
     cppState.addRegion(std::unique_ptr<Region>(unwrap(state->regions[i])));
 
-  MlirOperation result = wrap(Operation::create(cppState));
   free(state->results);
   free(state->operands);
   free(state->successors);
   free(state->regions);
   free(state->attributes);
+
+  // Infer result types.
+  if (state->enableResultTypeInference) {
+    assert(cppState.types.empty() &&
+           "result type inference enabled and result types provided");
+    if (failed(inferOperationTypes(cppState)))
+      return {nullptr};
+  }
+
+  MlirOperation result = wrap(Operation::create(cppState));
   return result;
 }
 
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -553,6 +553,35 @@
   // clang-format on
 }
 
+/// Creates an operation with result type inference and checks the result type.
+static int createOperationWithTypeInference(MlirContext ctx) {
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirAttribute iAttr = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 4);
+
+  // The shape.const_size op implements result type inference and is only used
+  // for that reason.
+  MlirOperationState state = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("shape.const_size"), loc);
+  MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), iAttr);
+  mlirOperationStateAddAttributes(&state, 1, &valueAttr);
+  mlirOperationStateEnableResultTypeInference(&state);
+
+  // Expect result type inference to succeed.
+  MlirOperation op = mlirOperationCreate(&state);
+  if (mlirOperationIsNull(op)) {
+    fprintf(stderr, "ERROR: Result type inference unexpectedly failed\n");
+    return 1;
+  }
+
+  // CHECK: RESULT_TYPE_INFERENCE: !shape.size
+  fprintf(stderr, "RESULT_TYPE_INFERENCE: ");
+  mlirTypeDump(mlirValueGetType(mlirOperationGetResult(op, 0)));
+  fprintf(stderr, "\n");
+  mlirOperationDestroy(op);
+  return 0;
+}
+
 /// Dumps instances of all builtin types to check that C API works correctly.
 /// Additionally, performs simple identity checks that a builtin type
 /// constructed with C API can be inspected and has the expected type. The
@@ -957,14 +986,12 @@
       (uint64_t *)mlirDenseElementsAttrGetRawData(uint64Elements);
   int64_t *int64RawData =
       (int64_t *)mlirDenseElementsAttrGetRawData(int64Elements);
-  float *floatRawData =
-      (float *)mlirDenseElementsAttrGetRawData(floatElements);
+  float *floatRawData = (float *)mlirDenseElementsAttrGetRawData(floatElements);
   double *doubleRawData =
       (double *)mlirDenseElementsAttrGetRawData(doubleElements);
   if (uint32RawData[0] != 0u || uint32RawData[1] != 1u ||
-      int32RawData[0] != 0 || int32RawData[1] != 1 ||
-      uint64RawData[0] != 0u || uint64RawData[1] != 1u ||
-      int64RawData[0] != 0 || int64RawData[1] != 1 ||
+      int32RawData[0] != 0 || int32RawData[1] != 1 || uint64RawData[0] != 0u ||
+      uint64RawData[1] != 1u || int64RawData[0] != 0 || int64RawData[1] != 1 ||
       floatRawData[0] != 0.0f || floatRawData[1] != 1.0f ||
       doubleRawData[0] != 0.0 || doubleRawData[1] != 1.0)
     return 18;
@@ -1389,19 +1416,21 @@
   if (constructAndTraverseIr(ctx))
     return 1;
   buildWithInsertionsAndPrint(ctx);
+  if (createOperationWithTypeInference(ctx))
+    return 2;
 
   if (printBuiltinTypes(ctx))
-    return 2;
-  if (printBuiltinAttributes(ctx))
     return 3;
-  if (printAffineMap(ctx))
+  if (printBuiltinAttributes(ctx))
     return 4;
-  if (printAffineExpr(ctx))
+  if (printAffineMap(ctx))
     return 5;
-  if (affineMapFromExprs(ctx))
+  if (printAffineExpr(ctx))
     return 6;
-  if (registerOnlyStd())
+  if (affineMapFromExprs(ctx))
     return 7;
+  if (registerOnlyStd())
+    return 8;
 
   mlirContextDestroy(ctx);
 