diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -326,6 +326,10 @@
 /// - Result type inference is enabled and cannot be performed.
 MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state);
 
+/// Creates a deep copy of an operation. The operation is not inserted and
+/// ownership is transferred to the caller.
+MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op);
+
 /// Takes an operation owned by the caller and destroys it.
 MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op);
 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -313,6 +313,10 @@
   return result;
 }
 
+MlirOperation mlirOperationClone(MlirOperation op) {
+  return wrap(unwrap(op)->clone());
+}
+
 void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); }
 
 bool mlirOperationEqual(MlirOperation op, MlirOperation other) {
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -1619,6 +1619,45 @@
   return 0;
 }
 
+/// Tests clone APIs: the clone must be independent of the original, so
+/// changing an attribute on the clone must not affect the original operation.
+int testClone() {
+  fprintf(stderr, "@testClone\n");
+  // CHECK-LABEL: @testClone
+
+  MlirContext ctx = mlirContextCreate();
+  MlirLocation loc = mlirLocationUnknownGet(ctx);
+  MlirType indexType = mlirIndexTypeGet(ctx);
+  MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
+
+  MlirAttribute indexZeroLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
+  MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
+      mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
+  MlirOperationState constZeroState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("std.constant"), loc);
+  mlirOperationStateAddResults(&constZeroState, 1, &indexType);
+  mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
+  MlirOperation constZero = mlirOperationCreate(&constZeroState);
+
+  MlirAttribute indexOneLiteral =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
+  MlirOperation constOne = mlirOperationClone(constZero);
+  mlirOperationSetAttributeByName(constOne, valueStringRef, indexOneLiteral);
+
+  mlirOperationPrint(constZero, printToStderr, NULL);
+  mlirOperationPrint(constOne, printToStderr, NULL);
+  // CHECK: %0 = "std.constant"() {value = 0 : index} : () -> index
+  // CHECK: %0 = "std.constant"() {value = 1 : index} : () -> index
+
+  // Neither operation was inserted into a block, so the test must release
+  // both (before the context they live in), then the context it created.
+  mlirOperationDestroy(constZero);
+  mlirOperationDestroy(constOne);
+  mlirContextDestroy(ctx);
+  return 0;
+}
+
 // Wraps a diagnostic into additional text we can match against.
 MlirLogicalResult errorHandler(MlirDiagnostic diagnostic, void *userData) {
   fprintf(stderr, "processing diagnostic (userData: %ld) <<\n", (long)userData);
@@ -1698,6 +1737,8 @@
     return 10;
   if (testOperands())
     return 11;
+  if (testClone())
+    return 12;
 
   mlirContextDestroy(ctx);
 