diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -761,6 +761,10 @@
 // TypeID API.
 //===----------------------------------------------------------------------===//
 
+/// Creates a type id from the opaque pointer `ptr`. `ptr` must be unique to a
+/// single type and remain valid for as long as the returned type id is used.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr);
+
 /// Checks whether a type id is null.
 static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; }
 
diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h
--- a/mlir/include/mlir-c/Pass.h
+++ b/mlir/include/mlir-c/Pass.h
@@ -15,6 +15,7 @@
 #define MLIR_C_PASS_H
 
 #include "mlir-c/IR.h"
+#include "mlir-c/Registration.h"
 #include "mlir-c/Support.h"
 
 #ifdef __cplusplus
@@ -46,6 +47,10 @@
 
 #undef DEFINE_C_API_STRUCT
 
+//===----------------------------------------------------------------------===//
+// PassManager/OpPassManager APIs.
+//===----------------------------------------------------------------------===//
+
 /// Create a new top-level PassManager.
 MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx);
 
@@ -112,6 +117,52 @@
 MLIR_CAPI_EXPORTED MlirLogicalResult
 mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline);
 
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//
+// This API allows passes to be defined outside of MLIR, not necessarily in
+// C++, and registered with the MLIR pass management infrastructure.
+//
+//===----------------------------------------------------------------------===//
+
+/// Structure of external `MlirPass` callbacks.
+/// All callbacks are required to be set unless otherwise specified.
+struct MlirExternalPassCallbacks {
+  /// This callback is called when the pass, or a clone of it, is created.
+  /// This is analogous to a C++ pass constructor.
+  void (*construct)(void *userData);
+
+  /// This callback is called when the pass, or a clone of it, is destroyed.
+  /// This is analogous to a C++ pass destructor.
+  void (*destruct)(void *userData);
+
+  /// This callback is optional; when it is null, initialization succeeds.
+  /// The callback is called before the pass is run, allowing a chance to
+  /// initialize any complex state necessary for running the pass.
+  /// See Pass::initialize(MLIRContext *).
+  MlirLogicalResult (*initialize)(MlirContext ctx, void *userData);
+
+  /// This callback is called when the pass is cloned and returns the user
+  /// data for the clone. See Pass::clonePass().
+  void *(*clone)(void *userData);
+
+  /// This callback is called when the pass is run.
+  /// See Pass::runOnOperation().
+  MlirLogicalResult (*run)(MlirOperation op, void *userData);
+};
+typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks;
+
+/// Creates an external `MlirPass` that calls the supplied `callbacks` using the
+/// supplied `userData`. If `opName`'s length is 0, the pass is a generic
+/// operation pass. Otherwise it is an operation pass specific to the specified
+/// operation name. `dependentDialects` must point to `nDependentDialects`
+/// dialect handles; they, `name`, `argument` and `description` are copied.
+MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass(
+    MlirTypeID passID, MlirStringRef name, MlirStringRef argument,
+    MlirStringRef description, MlirStringRef opName,
+    intptr_t nDependentDialects, MlirDialectHandle *dependentDialects,
+    MlirExternalPassCallbacks callbacks, void *userData);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -791,6 +791,13 @@
 // TypeID API.
 //===----------------------------------------------------------------------===//
 
+MlirTypeID mlirTypeIDCreate(const void *ptr) {
+  // This is essentially a no-op that returns `ptr`, but going through the
+  // `TypeID` functions yields compiler errors should the `TypeID`
+  // API/representation ever change.
+  return wrap(TypeID::getFromOpaquePointer(ptr));
+}
+
 bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) {
   return unwrap(typeID1) == unwrap(typeID2);
 }
diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp
--- a/mlir/lib/CAPI/IR/Pass.cpp
+++ b/mlir/lib/CAPI/IR/Pass.cpp
@@ -77,3 +77,93 @@
   // stream and redirect to a diagnostic.
   return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager)));
 }
+
+//===----------------------------------------------------------------------===//
+// External Pass API.
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pass class wraps external passes defined in other languages using the
+/// MLIR C-interface. Every callback in `callbacks` is invoked with `userData`.
+class ExternalPass : public Pass {
+public:
+  /// Creates the pass and immediately invokes the `construct` callback; the
+  /// matching `destruct` callback is invoked from the destructor.
+  ExternalPass(TypeID passID, StringRef name, StringRef argument,
+               StringRef description, Optional<StringRef> opName,
+               ArrayRef<MlirDialectHandle> dependentDialects,
+               MlirExternalPassCallbacks callbacks, void *userData)
+      : Pass(passID, opName), name(name), argument(argument),
+        description(description), dependentDialects(dependentDialects),
+        callbacks(callbacks), userData(userData) {
+    callbacks.construct(userData);
+  }
+
+  ~ExternalPass() override { callbacks.destruct(userData); }
+
+  StringRef getName() const override { return name; }
+  StringRef getArgument() const override { return argument; }
+  StringRef getDescription() const override { return description; }
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    auto cRegistry = wrap(&registry);
+    for (auto dialect : dependentDialects)
+      mlirDialectHandleInsertDialect(dialect, cRegistry);
+  }
+
+protected:
+  /// The `initialize` callback is optional; without it initialization
+  /// trivially succeeds.
+  LogicalResult initialize(MLIRContext *ctx) override {
+    if (callbacks.initialize)
+      return unwrap(callbacks.initialize(wrap(ctx), userData));
+    return success();
+  }
+
+  /// A pass created without an operation name can be scheduled on any
+  /// operation; otherwise only on operations with that exact name.
+  bool canScheduleOn(RegisteredOperationName opName) const override {
+    if (Optional<StringRef> specifiedOpName = getOpName())
+      return opName.getStringRef() == *specifiedOpName;
+    return true;
+  }
+
+  void runOnOperation() override {
+    MlirLogicalResult result = callbacks.run(wrap(getOperation()), userData);
+    if (mlirLogicalResultIsFailure(result))
+      signalPassFailure();
+  }
+
+  /// The clone is built around the user data returned by the `clone`
+  /// callback, so its `construct`/`destruct` callbacks receive that data.
+  std::unique_ptr<Pass> clonePass() const override {
+    void *clonedUserData = callbacks.clone(userData);
+    return std::make_unique<ExternalPass>(
+        getTypeID(), name, argument, description, getOpName(),
+        dependentDialects, callbacks, clonedUserData);
+  }
+
+private:
+  std::string name;
+  std::string argument;
+  std::string description;
+  std::vector<MlirDialectHandle> dependentDialects;
+  MlirExternalPassCallbacks callbacks;
+  void *userData;
+};
+} // namespace
+
+MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name,
+                                MlirStringRef argument,
+                                MlirStringRef description, MlirStringRef opName,
+                                intptr_t nDependentDialects,
+                                MlirDialectHandle *dependentDialects,
+                                MlirExternalPassCallbacks callbacks,
+                                void *userData) {
+  // An empty `opName` creates a generic pass not anchored on any operation.
+  return wrap(new ExternalPass(
+      unwrap(passID), unwrap(name), unwrap(argument), unwrap(description),
+      opName.length > 0 ? Optional<StringRef>(unwrap(opName)) : None,
+      {dependentDialects, static_cast<size_t>(nDependentDialects)}, callbacks,
+      userData));
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -49,6 +49,7 @@
 _add_capi_test_executable(mlir-capi-pass-test
   pass.c
   LINK_LIBS PRIVATE
+    MLIRCAPIFunc
     MLIRCAPIIR
     MLIRCAPIRegistration
     MLIRCAPITransforms
diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c
--- a/mlir/test/CAPI/pass.c
+++ b/mlir/test/CAPI/pass.c
@@ -11,6 +11,7 @@
  */
 
 #include "mlir-c/Pass.h"
+#include "mlir-c/Dialect/Func.h"
 #include "mlir-c/IR.h"
 #include "mlir-c/Registration.h"
 #include "mlir-c/Transforms.h"
@@ -169,7 +170,9 @@
           " func.func(print-op-stats))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsSuccess(status)) {
-    fprintf(stderr, "Unexpected success parsing pipeline without registering the pass\n");
+    fprintf(
+        stderr,
+        "Unexpected success parsing pipeline without registering the pass\n");
     exit(EXIT_FAILURE);
   }
   // Try again after registrating the pass.
@@ -180,7 +183,8 @@
           " func.func(print-op-stats))"));
   // Expect a failure, we haven't registered the print-op-stats pass yet.
   if (mlirLogicalResultIsFailure(status)) {
-    fprintf(stderr, "Unexpected failure parsing pipeline after registering the pass\n");
+    fprintf(stderr,
+            "Unexpected failure parsing pipeline after registering the pass\n");
     exit(EXIT_FAILURE);
   }
 
@@ -194,10 +198,328 @@
   mlirContextDestroy(ctx);
 }
 
+struct TestExternalPassUserData {
+  int constructCallCount;
+  int destructCallCount;
+  int initializeCallCount;
+  int cloneCallCount;
+  int runCallCount;
+};
+typedef struct TestExternalPassUserData TestExternalPassUserData;
+
+void testConstructExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->constructCallCount;
+}
+
+void testDestructExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->destructCallCount;
+}
+
+MlirLogicalResult testInitializeExternalPass(MlirContext ctx, void *userData) {
+  ++((TestExternalPassUserData *)userData)->initializeCallCount;
+  return mlirLogicalResultSuccess();
+}
+
+MlirLogicalResult testInitializeFailingExternalPass(MlirContext ctx,
+                                                    void *userData) {
+  ++((TestExternalPassUserData *)userData)->initializeCallCount;
+  return mlirLogicalResultFailure();
+}
+
+void *testCloneExternalPass(void *userData) {
+  ++((TestExternalPassUserData *)userData)->cloneCallCount;
+  return userData;
+}
+
+MlirLogicalResult testRunExternalPass(MlirOperation op, void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+  return mlirLogicalResultSuccess();
+}
+
+MlirLogicalResult testRunExternalFuncPass(MlirOperation op, void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+  MlirStringRef opName = mlirIdentifierStr(mlirOperationGetName(op));
+  if (mlirStringRefEqual(opName, mlirStringRefCreateFromCString("func.func"))) {
+    return mlirLogicalResultSuccess();
+  }
+  return mlirLogicalResultFailure();
+}
+
+MlirLogicalResult testRunFailingExternalPass(MlirOperation op, void *userData) {
+  ++((TestExternalPassUserData *)userData)->runCallCount;
+  return mlirLogicalResultFailure();
+}
+
+MlirExternalPassCallbacks makeTestExternalPassCallbacks(
+    MlirLogicalResult (*initializePass)(MlirContext ctx, void *userData),
+    MlirLogicalResult (*runPass)(MlirOperation op, void *userData)) {
+  return (MlirExternalPassCallbacks){testConstructExternalPass,
+                                     testDestructExternalPass, initializePass,
+                                     testCloneExternalPass, runPass};
+}
+
+void testExternalPass() {
+  MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
+  MlirModule module = mlirModuleCreateParse(
+      ctx,
+      // clang-format off
+      mlirStringRefCreateFromCString(
+"func @foo(%arg0 : i32) -> i32 { \n"
+" %res = arith.addi %arg0, %arg0 : i32 \n"
+" return %res : i32 \n"
+"}"));
+  // clang-format on
+  if (mlirModuleIsNull(module)) {
+    fprintf(stderr, "Unexpected failure parsing module.\n");
+    exit(EXIT_FAILURE);
+  }
+
+  MlirStringRef description = mlirStringRefCreateFromCString("");
+  MlirStringRef emptyOpName = mlirStringRefCreateFromCString("");
+
+  // Run a generic pass
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(NULL, testRunExternalPass), &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a func operation pass
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalFuncPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-func-pass");
+    TestExternalPassUserData userData = {0};
+    MlirDialectHandle funcHandle = mlirGetDialectHandle__func__();
+    MlirStringRef funcOpName = mlirStringRefCreateFromCString("func.func");
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, funcOpName, 1, &funcHandle,
+        makeTestExternalPassCallbacks(NULL, testRunExternalFuncPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    MlirOpPassManager nestedFuncPm =
+        mlirPassManagerGetNestedUnder(pm, funcOpName);
+    mlirOpPassManagerAddOwnedPass(nestedFuncPm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external operation pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // Since this is a nested pass, it can be cloned and run in parallel.
+    if (userData.cloneCallCount != userData.constructCallCount - 1) {
+      fprintf(stderr, "Expected cloneCallCount to be constructCallCount - 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    // The pass should only be run once since there is only one func op.
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass with `initialize` set
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name = mlirStringRefCreateFromCString("TestExternalPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(testInitializeExternalPass,
+                                      testRunExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsFailure(success)) {
+      fprintf(stderr, "Unexpected failure running external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass that fails during `initialize`
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name =
+        mlirStringRefCreateFromCString("TestExternalFailingPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-failing-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(testInitializeFailingExternalPass,
+                                      testRunExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsSuccess(success)) {
+      fprintf(
+          stderr,
+          "Expected failure running pass manager on failing external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.initializeCallCount != 1) {
+      fprintf(stderr, "Expected initializeCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 0) {
+      fprintf(stderr, "Expected runCallCount to be 0\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // Run a pass that fails during `run`
+  {
+    static int typeIDStorage;
+    MlirTypeID passID = mlirTypeIDCreate(&typeIDStorage);
+    MlirStringRef name =
+        mlirStringRefCreateFromCString("TestExternalFailingPass");
+    MlirStringRef argument =
+        mlirStringRefCreateFromCString("test-external-failing-pass");
+    TestExternalPassUserData userData = {0};
+
+    MlirPass externalPass = mlirCreateExternalPass(
+        passID, name, argument, description, emptyOpName, 0, NULL,
+        makeTestExternalPassCallbacks(NULL, testRunFailingExternalPass),
+        &userData);
+
+    if (userData.constructCallCount != 1) {
+      fprintf(stderr, "Expected constructCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    MlirPassManager pm = mlirPassManagerCreate(ctx);
+    mlirPassManagerAddOwnedPass(pm, externalPass);
+    MlirLogicalResult success = mlirPassManagerRun(pm, module);
+    if (mlirLogicalResultIsSuccess(success)) {
+      fprintf(
+          stderr,
+          "Expected failure running pass manager on failing external pass.\n");
+      exit(EXIT_FAILURE);
+    }
+
+    if (userData.runCallCount != 1) {
+      fprintf(stderr, "Expected runCallCount to be 1\n");
+      exit(EXIT_FAILURE);
+    }
+
+    mlirPassManagerDestroy(pm);
+
+    if (userData.destructCallCount != userData.constructCallCount) {
+      fprintf(stderr, "Expected destructCallCount to be equal to "
+                      "constructCallCount\n");
+      exit(EXIT_FAILURE);
+    }
+  }
+
+  // The parsed module is owned by the caller and must be released before
+  mlirModuleDestroy(module);
+  mlirContextDestroy(ctx);
+}
+
 int main() {
   testRunPassOnModule();
   testRunPassOnNestedModule();
   testPrintPassPipeline();
   testParsePassPipeline();
+  testExternalPass();
   return 0;
 }