diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -761,6 +761,10 @@
 // TypeID API.
 //===----------------------------------------------------------------------===//
 
+/// Creates an `MlirTypeID` from `ptr`, which must uniquely identify one type
+/// and remain valid for as long as the returned type id is in use.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
+
 /// Checks whether a type id is null.
 static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
 
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -15,6 +15,7 @@
 #define MLIR_C_PASS_H
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Registration.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -46,6 +47,10 @@
 
 #undef DEFINE_C_API_STRUCT
 
+//===----------------------------------------------------------------------===//
+// PassManager/OpPassManager APIs.
+//===----------------------------------------------------------------------===//
+
 /// Create a new top-level PassManager.
 MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
 
@@ -112,6 +117,36 @@
 MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
 
+//===----------------------------------------------------------------------===//
+// Dynamic Pass API.
+//===----------------------------------------------------------------------===//
+
+/// Callbacks invoked by a dynamic `MlirPass`. Each callback receives the
+/// `userData` of the pass instance it is called for.
+struct MlirDynamicPassCallbacks {
+  /// Called when the pass, or a clone of it, is created.
+  void (*initialize)(void *userData);
+  /// Called when the pass, or a clone of it, is destroyed.
+  void (*deinitialize)(void *userData);
+  /// Called when the pass is cloned. Returns the `userData` for the clone.
+  void *(*clone)(void *userData);
+  /// Called when the pass is run on `op`. Returning a failure signals a pass
+  /// failure.
+  MlirLogicalResult (*run)(MlirOperation op, void *userData);
+};
+typedef struct MlirDynamicPassCallbacks MlirDynamicPassCallbacks;
+
+/// Creates a dynamic `MlirPass` that calls the supplied `callbacks` using the
+/// supplied `userData`. If `opName` is empty, the pass is a generic operation
+/// pass. Otherwise it is an operation pass that can only be scheduled on
+/// operations named `opName`. The strings and the `dependentDialects` array
+/// are copied, so they only need to remain valid for the duration of the call.
+MLIR_CAPI_EXPORTED MlirPass mlirCreateDynamicPass(
+    MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
+    MlirStringRef description, MlirStringRef opName,
+    intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
+    MlirDynamicPassCallbacks callbacks, void *userData);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -791,6 +791,13 @@
 // TypeID API.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirTypeIDCreate(const void *ptr) {
+  // This is essentially a no-op that returns back `ptr`, but by going through
+  // the `TypeID` functions we can get compiler errors in case the `TypeID`
+  // API/representation changes.
+  return wrap(TypeID::getFromOpaquePointer(ptr));
+}
+
 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
   return unwrap(typeID1) == unwrap(typeID2);
 }
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -77,3 +77,107 @@
   // stream and redirect to a diagnostic.
   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic Pass API.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Owns a copy of the operation name that a `DynamicPass` is restricted to.
+///
+/// `Pass` keeps only a non-owning `StringRef` to its operation name, whereas
+/// the name given to `mlirCreateDynamicPass` belongs to the caller and may be
+/// freed right after the call. `DynamicPass` inherits from this class *before*
+/// `Pass`, so the copy is constructed first and lives as long as the pass.
+class DynamicPassOpNameStorage {
+protected:
+  explicit DynamicPassOpNameStorage(Optional<StringRef> opName)
+      : opNameStorage(opName ? Optional<std::string>(opName->str()) : None) {}
+
+  /// Returns the owned operation name, or None for a generic pass.
+  Optional<StringRef> getOwnedOpName() const {
+    if (!opNameStorage)
+      return None;
+    return StringRef(*opNameStorage);
+  }
+
+private:
+  Optional<std::string> opNameStorage;
+};
+
+/// This pass class wraps dynamic passes defined in other languages using the
+/// MLIR C-interface. Every callback receives this pass' `userData`.
+class DynamicPass : private DynamicPassOpNameStorage, public Pass {
+public:
+  DynamicPass(TypeID passID, StringRef name, StringRef argument,
+              StringRef description, Optional<StringRef> opName,
+              ArrayRef<MlirDialectHandle> dependentDialects,
+              MlirDynamicPassCallbacks callbacks, void *userData)
+      : DynamicPassOpNameStorage(opName), Pass(passID, getOwnedOpName()),
+        id(passID), name(name), argument(argument), description(description),
+        dependentDialects(dependentDialects), callbacks(callbacks),
+        userData(userData) {
+    callbacks.initialize(userData);
+  }
+
+  ~DynamicPass() override { callbacks.deinitialize(userData); }
+
+  StringRef getName() const override { return name; }
+  StringRef getArgument() const override { return argument; }
+  StringRef getDescription() const override { return description; }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    auto cRegistry = wrap(&registry);
+    for (MlirDialectHandle dialect : dependentDialects)
+      mlirDialectHandleInsertDialect(dialect, cRegistry);
+  }
+
+protected:
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    // A generic pass (no operation name) can be scheduled on any operation.
+    if (Optional<StringRef> specifiedOpName = getOpName())
+      return opName.getStringRef() == *specifiedOpName;
+    return true;
+  }
+
+  void runOnOperation() override {
+    MlirLogicalResult result = callbacks.run(wrap(getOperation()), userData);
+    if (mlirLogicalResultIsFailure(result))
+      signalPassFailure();
+  }
+
+  std::unique_ptr<Pass> clonePass() const override {
+    // The clone owns the user data returned by `clone`; its constructor calls
+    // `initialize` on it and its destructor calls `deinitialize`.
+    void *clonedUserData = callbacks.clone(userData);
+    return std::make_unique<DynamicPass>(id, name, argument, description,
+                                         getOpName(), dependentDialects,
+                                         callbacks, clonedUserData);
+  }
+
+private:
+  TypeID id;
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::vector<MlirDialectHandle> dependentDialects;
+  MlirDynamicPassCallbacks callbacks;
+  void *userData;
+};
+} // namespace
+
+MlirPass mlirCreateDynamicPass(MlirTypeID passID, MlirStringRef name,
+                               MlirStringRef argument,
+                               MlirStringRef description, MlirStringRef opName,
+                               intptr_t nDependentDialects,
+                               MlirDialectHandle *dependentDialects,
+                               MlirDynamicPassCallbacks callbacks,
+                               void *userData) {
+  // An empty `opName` requests a generic operation pass. All strings and the
+  // dialect handle array are copied by `DynamicPass`.
+  return wrap(new DynamicPass(
+      unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
+      opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
+      {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
+      userData));
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -49,6 +49,7 @@
 _add_capi_test_executable(mlir-capi-pass-test
   pass.c
   LINK_LIBS PRIVATE
+    MLIRCAPIFunc
     MLIRCAPIIR
     MLIRCAPIRegistration
     MLIRCAPITransforms
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -11,6 +11,7 @@
  */
 
 #include "mlir-c/Pass.h"
+#include "mlir-c/Dialect/Func.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/Transforms.h"
@@ -165,26 +166,30 @@
   // Try parse a pipeline.
   MlirLogicalResult status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString(
-          "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
+      mlirStringRefCreateFromCString("builtin.module(builtin.func(print-op-"
+                                     "stats), builtin.func(print-op-stats))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
-    fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+    fprintf(
+        stderr,
+        "Unexpected success parsing pipeline without registering the pass\n");
     exit(EXIT_FAILURE);
   }
-  // Try again after registrating the pass.
+  // Try again after registering the pass.
   mlirRegisterTransformsPrintOpStats();
   status = mlirParsePassPipeline(
       mlirPassManagerGetAsOpPassManager(pm),
-      mlirStringRefCreateFromCString(
-          "builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))"));
-  // Expect a failure, we haven't registered the print-op-stats pass yet.
+      mlirStringRefCreateFromCString("builtin.module(builtin.func(print-op-"
+                                     "stats), builtin.func(print-op-stats))"));
+  // Expect success now that the print-op-stats pass is registered.
   if (mlirLogicalResultIsFailure(status)) {
-    fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+    fprintf(stderr,
+            "Unexpected failure parsing pipeline after registering the pass\n");
     exit(EXIT_FAILURE);
   }
-  // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats), builtin.func(print-op-stats))
+  // CHECK: Round-trip: builtin.module(builtin.func(print-op-stats),
+  // CHECK-SAME: builtin.func(print-op-stats))
   fprintf(stderr, "Round-trip: ");
   mlirPrintPassPipeline(mlirPassManagerGetAsOpPassManager(pm), printToStderr,
                         NULL);
@@ -193,10 +198,221 @@
   mlirContextDestroy(ctx);
 }
 
+struct TestDynamicPassUserData {
+  int initializeCallCount;
+  int deinitializeCallCount;
+  int cloneCallCount;
+  int runCallCount;
+};
+typedef struct TestDynamicPassUserData TestDynamicPassUserData;
+
+void testInitializeDynamicPass(void *userData) {
+  ++((TestDynamicPassUserData *)userData)->initializeCallCount;
+}
+
+void testDeinitializeDynamicPass(void *userData) {
+  ++((TestDynamicPassUserData *)userData)->deinitializeCallCount;
+}
+
+void *testCloneDynamicPass(void *userData) {
+  ++((TestDynamicPassUserData *)userData)->cloneCallCount;
+  return userData;
+}
+
+MlirLogicalResult testRunDynamicPass(MlirOperation op, void *userData) {
+  ++((TestDynamicPassUserData *)userData)->runCallCount;
+  return mlirLogicalResultSuccess();
+}
+
+MlirLogicalResult testRunDynamicFuncPass(MlirOperation op, void *userData) {
+  ++((TestDynamicPassUserData *)userData)->runCallCount;
+  MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
+  if (mlirStringRefEqual(opName,
+                         mlirStringRefCreateFromCString("builtin.func"))) {
+    return mlirLogicalResultSuccess();
+  }
+  return mlirLogicalResultFailure();
+}
+
+MlirLogicalResult testRunFailingDynamicPass(MlirOperation op, void *userData) {
+  ++((TestDynamicPassUserData *)userData)->runCallCount;
+  return mlirLogicalResultFailure();
+}
+
+MlirDynamicPassCallbacks makeTestDynamicPassCallbacks(
+    MlirLogicalResult (*runPass)(MlirOperation op, void *userData)) {
+  return (MlirDynamicPassCallbacks){testInitializeDynamicPass,
+                                    testDeinitializeDynamicPass,
+                                    testCloneDynamicPass, runPass};
+}
+
+void testDynamicPass() {
+  MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
+
+  MlirModule module = mlirModuleCreateParse(
+      ctx,
+      // clang-format off
+      mlirStringRefCreateFromCString(
+"func @foo(%arg0 : i32) -> i32 { \n"
+" %res = arith.addi %arg0, %arg0 : i32 \n"
+" return %res : i32 \n"
+"}"));
+  // clang-format on
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "Unexpected failure parsing module.\n");
+    exit(EXIT_FAILURE);
+  }
+
+  MlirStringRef description = mlirStringRefCreateFromCString("");
+  MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+
+  // Run a generic pass
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestDynamicPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-dynamic-pass");
+    TestDynamicPassUserData userData = {0};
+
+    MlirPass dynamicPass = mlirCreateDynamicPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestDynamicPassCallbacks(testRunDynamicPass), &userData);
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, dynamicPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running dynamic pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.deinitializeCallCount != userData.initializeCallCount) {
+      fprintf(stderr, "Expected deinitializeCallCount to be equal to "
+                      "initializeCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a func operation pass
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestDynamicFuncPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-dynamic-func-pass");
+    TestDynamicPassUserData userData = {0};
+    MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
+    MlirStringRef funcOpName = mlirStringRefCreateFromCString("builtin.func");
+
+    MlirPass dynamicPass = mlirCreateDynamicPass(
+        passID, name, argument, description, funcOpName, 1, &funcHandle,
+        makeTestDynamicPassCallbacks(testRunDynamicFuncPass), &userData);
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    MlirOpPassManager nestedFuncPm =
+        mlirPassManagerGetNestedUnder(pm, funcOpName);
+    mlirOpPassManagerAddOwnedPass(nestedFuncPm, dynamicPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running dynamic operation pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // Since this is a nested pass, it can be cloned and run in parallel. Every
+    // clone is initialized too, so there is exactly one more initialization
+    // than there are clones.
+    if (userData.cloneCallCount != userData.initializeCallCount - 1) {
+      fprintf(stderr,
+              "Expected cloneCallCount to be initializeCallCount - 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // The pass should only be run once since there is only one func op
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.deinitializeCallCount != userData.initializeCallCount) {
+      fprintf(stderr, "Expected deinitializeCallCount to be equal to "
+                      "initializeCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a failing pass
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name =
+        mlirStringRefCreateFromCString("TestDynamicFailingPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-dynamic-failing-pass");
+    TestDynamicPassUserData userData = {0};
+
+    MlirPass dynamicPass = mlirCreateDynamicPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestDynamicPassCallbacks(testRunFailingDynamicPass), &userData);
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, dynamicPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsSuccess(success)) {
+      fprintf(
+          stderr,
+          "Expected failure running pass manager on failing dynamic pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.deinitializeCallCount != userData.initializeCallCount) {
+      fprintf(stderr, "Expected deinitializeCallCount to be equal to "
+                      "initializeCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  mlirModuleDestroy(module);
+  mlirContextDestroy(ctx);
+}
+
 int main() {
   testRunPassOnModule();
   testRunPassOnNestedModule();
   testPrintPassPipeline();
   testParsePassPipeline();
+  testDynamicPass();
   return 0;
 }