diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -194,6 +194,9 @@
 MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context,
                                                     MlirStringRef module);
 
+/// Creates a deep copy of a module. The caller owns the result.
+MLIR_CAPI_EXPORTED MlirModule mlirModuleClone(MlirModule module);
+
 /// Gets the context that a module was created with.
 MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module);
 
@@ -326,6 +329,9 @@
 /// - Result type inference is enabled and cannot be performed.
 MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
 
+/// Creates a deep copy of an operation. The caller owns the result.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
+
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -164,6 +164,10 @@
   return MlirModule{owning.release().getOperation()};
 }
 
+MlirModule mlirModuleClone(MlirModule module) {
+  return wrap(unwrap(module).clone());
+}
+
 MlirContext mlirModuleGetContext(MlirModule module) {
   return wrap(unwrap(module).getContext());
 }
@@ -313,6 +317,10 @@
   return result;
 }
 
+MlirOperation mlirOperationClone(MlirOperation op) {
+  return wrap(unwrap(op)->clone());
+}
+
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1619,6 +1619,61 @@
   return 0;
 }
 
+/// Returns the number of operations directly contained in `block`.
+intptr_t countOperations(MlirBlock block) {
+  intptr_t count = 0;
+  for (MlirOperation op = mlirBlockGetFirstOperation(block);
+       !mlirOperationIsNull(op); op = mlirOperationGetNextInBlock(op)) {
+    ++count;
+  }
+  return count;
+}
+
+/// Tests clone APIs.
+int testClone(void) {
+  fprintf(stderr, "@testClone\n");
+  // CHECK-LABEL: @testClone
+
+  MlirContext ctx = mlirContextCreate();
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirModule module = mlirModuleCreateEmpty(loc);
+  MlirBlock body = mlirModuleGetBody(module);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+
+  // Create a constant op (`std.constant`) with the value `1 : index`.
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
+      indexOneLiteral);
+  MlirOperationState constOneState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constOneState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
+  MlirOperation original = mlirOperationCreate(&constOneState);
+
+  // Clone the constant and destroy the original; only the clone is used below.
+  MlirOperation constOne = mlirOperationClone(original);
+  mlirOperationDestroy(original);
+
+  // Clone the module first, then check that modifying the original module
+  // does not modify its clone.
+  MlirModule moduleClone = mlirModuleClone(module);
+  MlirBlock cloneBody = mlirModuleGetBody(moduleClone);
+  mlirBlockAppendOwnedOperation(body, constOne);
+
+  fprintf(stderr, "Module operation count: %ld\n", (long)countOperations(body));
+  fprintf(stderr, "Clone operation count: %ld\n",
+          (long)countOperations(cloneBody));
+  // CHECK: Module operation count: 1
+  // CHECK: Clone operation count: 0
+
+  // `constOne` was handed to `module` above, so the module is destroyed here.
+  mlirModuleDestroy(moduleClone);
+  mlirModuleDestroy(module);
+  mlirContextDestroy(ctx);
+  return 0;
+}
+
 // Wraps a diagnostic into additional text we can match against.
MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) { fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData); @@ -1698,6 +1753,8 @@ return 10; if (testOperands()) return 11; + if (testClone()) + return 12; mlirContextDestroy(ctx);