diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -276,6 +276,7 @@
   intptr_t nAttributes;
   MlirNamedAttribute *attributes;
   bool enableResultTypeInference;
+  bool enablePopulateDefaultAttributes;
 };
 typedef struct MlirOperationState MlirOperationState;
 
@@ -308,6 +309,14 @@
 MLIR_CAPI_EXPORTED void
 mlirOperationStateEnableResultTypeInference(MlirOperationState *state);
 
+/// Enables populating default attributes for the operation under construction.
+/// When enabled, mlirOperationCreate() asks the registered operation to add
+/// its default-valued attributes to the state. If the operation is not
+/// registered, a warning is emitted and no attributes are added. Unlike result
+/// type inference, this may be combined with mlirOperationStateAddResults().
+MLIR_CAPI_EXPORTED void
+mlirOperationStateEnablePopulateDefaultAttributes(MlirOperationState *state);
+
 //===----------------------------------------------------------------------===//
 // Op Printing flags API.
 // While many of these are simple settings that could be represented in a
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1212,6 +1212,12 @@
                                       mlirRegions.data());
   }
 
+  // FIXME: Enabled unconditionally only to demonstrate usage from Python.
+  // As a consequence, creating an unregistered operation emits a warning.
+  // The special casing for known default attribute types should be removed
+  // as part of this change.
+  state.enablePopulateDefaultAttributes = true;
+
   // Construct the operation.
   MlirOperation operation = mlirOperationCreate(&state);
   PyOperationRef created =
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -242,6 +242,7 @@
   state.nAttributes = 0;
   state.attributes = nullptr;
   state.enableResultTypeInference = false;
+  state.enablePopulateDefaultAttributes = false;
   return state;
 }
 
@@ -312,6 +313,28 @@
   return failure();
 }
 
+/// Adds the registered operation's default-valued attributes to `state`.
+static void populateDefaultAttributes(OperationState &state) {
+  Optional<RegisteredOperationName> info = state.name.getRegisteredInfo();
+  if (!info) {
+    // Not fatal: an unregistered operation simply has no defaults to add.
+    emitWarning(state.location)
+        << "populating default attributes was requested for the operation "
+        << state.name
+        << ", but the operation was not registered. Ensure that the dialect "
+           "containing the operation is linked into MLIR and registered with "
+           "the context";
+    return;
+  }
+
+  info->populateDefaultAttrs(state.attributes);
+}
+
+void mlirOperationStateEnablePopulateDefaultAttributes(
+    MlirOperationState *state) {
+  state->enablePopulateDefaultAttributes = true;
+}
+
 MlirOperation mlirOperationCreate(MlirOperationState *state) {
   assert(state);
   OperationState cppState(unwrap(state->location), unwrap(state->name));
@@ -338,6 +361,10 @@
   free(state->regions);
   free(state->attributes);
 
+  // Populate default-valued attributes, if requested.
+  if (state->enablePopulateDefaultAttributes)
+    populateDefaultAttributes(cppState);
+
   // Infer result types.
   if (state->enableResultTypeInference) {
     assert(cppState.types.empty() &&
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -617,6 +617,40 @@
   return 0;
 }
 
+/// Creates operations with default attributes populated.
+static int createOperationWithDefaultAttrs(MlirContext ctx) {
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+
+  // The scf.foreach_thread has default valued attributes and is only used for
+  // that reason.
+  MlirStringRef name = mlirStringRefCreateFromCString("scf.foreach_thread");
+
+  // Create without and with default attributes populated. A fresh state is
+  // used for each operation since mlirOperationCreate() consumes its state.
+  MlirOperationState noDefState = mlirOperationStateGet(name, loc);
+  MlirOperation opNodef = mlirOperationCreate(&noDefState);
+  MlirOperationState defState = mlirOperationStateGet(name, loc);
+  mlirOperationStateEnablePopulateDefaultAttributes(&defState);
+  MlirOperation opDef = mlirOperationCreate(&defState);
+
+  if (mlirOperationIsNull(opDef) || mlirOperationIsNull(opNodef)) {
+    fprintf(stderr, "ERROR: Operation creation unexpectedly failed\n");
+    return 1;
+  }
+
+  // CHECK: NOPOPULATE_DEFAULT: "scf.foreach_thread"() : () -> ()
+  fprintf(stderr, "NOPOPULATE_DEFAULT: ");
+  mlirOperationDump(opNodef);
+  // CHECK: POPULATE_DEFAULT: "scf.foreach_thread"()
+  // CHECK-SAME: {thread_dim_mapping = []} : () -> ()
+  fprintf(stderr, "POPULATE_DEFAULT: ");
+  mlirOperationDump(opDef);
+  fprintf(stderr, "\n");
+  mlirOperationDestroy(opDef);
+  mlirOperationDestroy(opNodef);
+  return 0;
+}
+
 /// Dumps instances of all builtin types to check that C API works correctly.
 /// Additionally, performs simple identity checks that a builtin type
 /// constructed with C API can be inspected and has the expected type. The
@@ -2042,33 +2076,34 @@
   buildWithInsertionsAndPrint(ctx);
   if (createOperationWithTypeInference(ctx))
     return 2;
-
-  if (printBuiltinTypes(ctx))
+  if (createOperationWithDefaultAttrs(ctx))
     return 3;
-  if (printBuiltinAttributes(ctx))
+  if (printBuiltinTypes(ctx))
     return 4;
-  if (printAffineMap(ctx))
+  if (printBuiltinAttributes(ctx))
     return 5;
-  if (printAffineExpr(ctx))
+  if (printAffineMap(ctx))
     return 6;
-  if (affineMapFromExprs(ctx))
+  if (printAffineExpr(ctx))
     return 7;
-  if (printIntegerSet(ctx))
+  if (affineMapFromExprs(ctx))
     return 8;
-  if (registerOnlyStd())
+  if (printIntegerSet(ctx))
     return 9;
-  if (testBackreferences())
+  if (registerOnlyStd())
     return 10;
-  if (testOperands())
+  if (testBackreferences())
     return 11;
-  if (testClone())
+  if (testOperands())
     return 12;
-  if (testTypeID(ctx))
+  if (testClone())
     return 13;
-  if (testSymbolTable(ctx))
+  if (testTypeID(ctx))
     return 14;
-  if (testDialectRegistry())
+  if (testSymbolTable(ctx))
     return 15;
+  if (testDialectRegistry())
+    return 16;
 
   mlirContextDestroy(ctx);
 