diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/Pass.h @@ -0,0 +1,90 @@ +/*===-- mlir-c/Pass.h - C API to Pass Management ------------------*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the C interface to the MLIR pass manager. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_PASS_H +#define MLIR_C_PASS_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*============================================================================*/ +/** Opaque type declarations. + * + * Types are exposed to C bindings as structs containing opaque pointers. They + * are not supposed to be inspected from C. This allows the underlying + * representation to change without affecting the API users. The use of structs + * instead of typedefs enables some type safety as structs are not implicitly + * convertible to each other. + * + * Instances of these types may or may not own the underlying object. The + * ownership semantics are defined by how an instance of the type was obtained. + */ +/*============================================================================*/ + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirPass, void); +DEFINE_C_API_STRUCT(MlirPassManager, void); +DEFINE_C_API_STRUCT(MlirOpPassManager, void); + +#undef DEFINE_C_API_STRUCT + +/** Create a new top-level PassManager; free it with mlirPassManagerDestroy. */ +MlirPassManager mlirPassManagerCreate(MlirContext ctx); + +/** Destroy the provided PassManager.
*/ +void mlirPassManagerDestroy(MlirPassManager passManager); + +/** Run the provided `passManager` on the given `module`. */ +MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, + MlirModule module); + +/** Nest an OpPassManager under the top-level PassManager. The nested + * pass manager will only run on operations matching the provided name. + * The returned OpPassManager will be destroyed when the parent is destroyed. + * To further nest more OpPassManager under the newly returned one, see + * `mlirOpPassManagerGetNestedUnder` below. */ +MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, + MlirStringRef operationName); + +/** Nest an OpPassManager under the provided OpPassManager. The nested + * pass manager will only run on operations matching the provided name. + * The returned OpPassManager will be destroyed when the parent is destroyed. */ +MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, + MlirStringRef operationName); + +/** Add a pass and transfer ownership to the provided top-level mlirPassManager. + * If the pass is not a generic operation pass or a ModulePass, a new + * OpPassManager is implicitly nested under the provided PassManager. */ +void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass); + +/** Add a pass and transfer ownership to the provided mlirOpPassManager. If the + * pass is not a generic operation pass or matching the type of the provided + * PassManager, a new OpPassManager is implicitly nested under the provided + * PassManager.
*/ +void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, + MlirPass pass); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_PASS_H diff --git a/mlir/include/mlir-c/Transforms.h b/mlir/include/mlir-c/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir-c/Transforms.h @@ -0,0 +1,20 @@ +/*===-- mlir-c/Transforms.h - C API for Transformation Passes -----*- C -*-===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +|*===----------------------------------------------------------------------===*| +|* *| +|* This header declares the creation methods for *| +|* transformation passes. *| +|* *| +\*===----------------------------------------------------------------------===*/ + +#ifndef MLIR_C_TRANSFORMS_H +#define MLIR_C_TRANSFORMS_H + +#include "mlir/Transforms/Transforms.capi.h.inc" + +#endif // MLIR_C_TRANSFORMS_H diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_INCLUDE_MLIR_CAPI_IR_H -#define MLIR_INCLUDE_MLIR_CAPI_IR_H +#ifndef MLIR_CAPI_IR_H +#define MLIR_CAPI_IR_H #include "mlir/CAPI/Wrap.h" #include "mlir/IR/Identifier.h" @@ -35,4 +35,4 @@ DEFINE_C_API_METHODS(MlirType, mlir::Type) DEFINE_C_API_METHODS(MlirValue, mlir::Value) -#endif // MLIR_INCLUDE_MLIR_CAPI_IR_H +#endif // MLIR_CAPI_IR_H diff --git a/mlir/include/mlir/CAPI/Pass.h b/mlir/include/mlir/CAPI/Pass.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/CAPI/Pass.h @@ -0,0 +1,28 @@ +//===- Pass.h - C API Utils for MLIR Pass Management ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR pass management classes. This file should not be included from C++ code +// other than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_PASS_H +#define MLIR_CAPI_PASS_H + +#include "mlir-c/Pass.h" + +#include "mlir/CAPI/Wrap.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +DEFINE_C_API_PTR_METHODS(MlirPass, mlir::Pass) +DEFINE_C_API_PTR_METHODS(MlirPassManager, mlir::PassManager) +DEFINE_C_API_PTR_METHODS(MlirOpPassManager, mlir::OpPassManager) + +#endif // MLIR_CAPI_PASS_H diff --git a/mlir/include/mlir/Transforms/CMakeLists.txt b/mlir/include/mlir/Transforms/CMakeLists.txt --- a/mlir/include/mlir/Transforms/CMakeLists.txt +++ b/mlir/include/mlir/Transforms/CMakeLists.txt @@ -1,6 +1,8 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transforms) +mlir_tablegen(Transforms.capi.h.inc -gen-pass-capi-header --prefix Transforms) +mlir_tablegen(Transforms.capi.cpp.inc -gen-pass-capi-impl --prefix Transforms) add_public_tablegen_target(MLIRTransformsPassIncGen) add_mlir_doc(Passes -gen-pass-doc GeneralPasses ./) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Standard) +add_subdirectory(Transforms) diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -4,6 +4,7 @@ AffineMap.cpp Diagnostics.cpp IR.cpp + Pass.cpp StandardAttributes.cpp StandardTypes.cpp
Support.cpp diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -0,0 +1,53 @@ +//===- Pass.cpp - C Interface for General Pass Management APIs ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Pass.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Pass.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Utils.h" +#include "mlir/Pass/PassManager.h" + +using namespace mlir; + +/* ========================================================================== */ +/* PassManager/OpPassManager APIs. */ +/* ========================================================================== */ + +MlirPassManager mlirPassManagerCreate(MlirContext ctx) { + return wrap(new PassManager(unwrap(ctx))); +} + +void mlirPassManagerDestroy(MlirPassManager passManager) { + delete unwrap(passManager); +} + +MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, + MlirModule module) { + return wrap(unwrap(passManager)->run(unwrap(module))); +} + +MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, + MlirStringRef operationName) { + return wrap(&unwrap(passManager)->nest(unwrap(operationName))); +} + +MlirOpPassManager mlirOpPassManagerGetNestedUnder(MlirOpPassManager passManager, + MlirStringRef operationName) { + return wrap(&unwrap(passManager)->nest(unwrap(operationName))); +} + +void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MlirPass pass) { + unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); +} + +void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, + MlirPass pass) { + unwrap(passManager)->addPass(std::unique_ptr<Pass>(unwrap(pass))); +} diff --git
a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt @@ -0,0 +1,10 @@ +add_mlir_library(MLIRCAPITransforms + + Passes.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir-c + + LINK_LIBS PUBLIC + MLIRTransforms + ) diff --git a/mlir/lib/CAPI/Transforms/Passes.cpp b/mlir/lib/CAPI/Transforms/Passes.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/CAPI/Transforms/Passes.cpp @@ -0,0 +1,23 @@ +//===- Passes.cpp - C API for Transformation Passes -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" +#include "mlir/CAPI/Pass.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Transforms/Transforms.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt --- a/mlir/test/CAPI/CMakeLists.txt +++ b/mlir/test/CAPI/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_OPTIONAL_SOURCES + ir.c + pass.c +) + set(LLVM_LINK_COMPONENTS Core Support @@ -15,3 +20,18 @@ MLIRCAPIRegistration MLIRCAPIStandard ${dialect_libs}) + + +add_llvm_executable(mlir-capi-pass-test + pass.c + ) +llvm_update_compile_flags(mlir-capi-pass-test) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +target_link_libraries(mlir-capi-pass-test + PRIVATE + MLIRCAPIIR + MLIRCAPIRegistration + MLIRCAPIStandard + MLIRCAPITransforms + ${dialect_libs}) diff --git a/mlir/test/CAPI/pass.c b/mlir/test/CAPI/pass.c new file mode 100644 --- /dev/null +++ b/mlir/test/CAPI/pass.c @@ -0,0 +1,124 @@ +/*===- pass.c - Simple test of C APIs
-------------------------------------===*\ +|* *| +|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| +|* Exceptions. *| +|* See https://llvm.org/LICENSE.txt for license information. *| +|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| +|* *| +\*===----------------------------------------------------------------------===*/ + +/* RUN: mlir-capi-pass-test 2>&1 | FileCheck %s + */ + +#include "mlir-c/Pass.h" +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" +#include "mlir-c/Transforms.h" + +#include <assert.h> +#include <inttypes.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +void testRunPassOnModule(void) { + MlirContext ctx = mlirContextCreate(); + mlirRegisterAllDialects(ctx); + + MlirModule module = + mlirModuleCreateParse(ctx, + // clang-format off +"func @foo(%arg0 : i32) -> i32 { \n" +" %res = addi %arg0, %arg0 : i32 \n" +" return %res : i32 \n" +"}"); + // clang-format on + if (mlirModuleIsNull(module)) + exit(EXIT_FAILURE); + + // Run the print-op-stats pass on the top-level module: + // CHECK-LABEL: Operations encountered: + // CHECK: func , 1 + // CHECK: module_terminator , 1 + // CHECK: std.addi , 1 + // CHECK: std.return , 1 + { + MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); + mlirPassManagerAddOwnedPass(pm, printOpStatPass); + MlirLogicalResult success = mlirPassManagerRun(pm, module); + if (mlirLogicalResultIsFailure(success)) + exit(EXIT_FAILURE); + mlirPassManagerDestroy(pm); + } + mlirModuleDestroy(module); + mlirContextDestroy(ctx); +} + +void testRunPassOnNestedModule(void) { + MlirContext ctx = mlirContextCreate(); + mlirRegisterAllDialects(ctx); + + MlirModule module = + mlirModuleCreateParse(ctx, + // clang-format off +"func @foo(%arg0 : i32) -> i32 { \n" +" %res = addi %arg0, %arg0 : i32 \n" +" return %res : i32 \n" +"} \n" +"module { \n" +" func @bar(%arg0 : f32) -> f32 { \n" +" %res = addf %arg0, %arg0 : f32 \n" +" return %res : f32 \n" +" } \n" +"}"); + // clang-format on
if (mlirModuleIsNull(module)) + exit(1); + + // Run the print-op-stats pass on functions under the top-level module: + // CHECK-LABEL: Operations encountered: + // CHECK-NOT: module_terminator + // CHECK: func , 1 + // CHECK: std.addi , 1 + // CHECK: std.return , 1 + { + MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirOpPassManager nestedFuncPm = mlirPassManagerGetNestedUnder( + pm, mlirStringRefCreateFromCString("func")); + MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); + mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); + MlirLogicalResult success = mlirPassManagerRun(pm, module); + if (mlirLogicalResultIsFailure(success)) + exit(2); + mlirPassManagerDestroy(pm); + } + // Run the print-op-stats pass on functions under the nested module: + // CHECK-LABEL: Operations encountered: + // CHECK-NOT: module_terminator + // CHECK: func , 1 + // CHECK: std.addf , 1 + // CHECK: std.return , 1 + { + MlirPassManager pm = mlirPassManagerCreate(ctx); + MlirOpPassManager nestedModulePm = mlirPassManagerGetNestedUnder( + pm, mlirStringRefCreateFromCString("module")); + MlirOpPassManager nestedFuncPm = mlirOpPassManagerGetNestedUnder( + nestedModulePm, mlirStringRefCreateFromCString("func")); + MlirPass printOpStatPass = mlirCreateTransformsPrintOpStats(); + mlirOpPassManagerAddOwnedPass(nestedFuncPm, printOpStatPass); + MlirLogicalResult success = mlirPassManagerRun(pm, module); + if (mlirLogicalResultIsFailure(success)) + exit(3); + mlirPassManagerDestroy(pm); + } + + mlirModuleDestroy(module); + mlirContextDestroy(ctx); +} + +int main(void) { + testRunPassOnModule(); + testRunPassOnNestedModule(); + return 0; +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -43,6 +43,7 @@ FileCheck count not MLIRUnitTests mlir-capi-ir-test + mlir-capi-pass-test mlir-cpu-runner mlir-edsc-builder-api-test mlir-linalg-ods-gen diff --git a/mlir/test/lit.cfg.py
--- a/mlir/test/lit.cfg.py +++ b/mlir/test/lit.cfg.py @@ -59,6 +59,7 @@ 'mlir-tblgen', 'mlir-translate', 'mlir-capi-ir-test', + 'mlir-capi-pass-test', 'mlir-edsc-builder-api-test', ] diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt --- a/mlir/tools/mlir-tblgen/CMakeLists.txt +++ b/mlir/tools/mlir-tblgen/CMakeLists.txt @@ -15,8 +15,9 @@ OpFormatGen.cpp OpInterfacesGen.cpp OpenMPCommonGen.cpp - PassGen.cpp + PassCAPIGen.cpp PassDocGen.cpp + PassGen.cpp RewriterGen.cpp SPIRVUtilsGen.cpp StructsGen.cpp diff --git a/mlir/tools/mlir-tblgen/PassCAPIGen.cpp b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-tblgen/PassCAPIGen.cpp @@ -0,0 +1,93 @@ +//===- PassCAPIGen.cpp - MLIR pass C API generator ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// PassCAPIGen uses the description of passes to generate a C API for them. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/GenInfo.h" +#include "mlir/TableGen/Pass.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +static llvm::cl::OptionCategory + passGenCat("Options for -gen-pass-capi-header and -gen-pass-capi-impl"); +static llvm::cl::opt<std::string> + groupName("prefix", + llvm::cl::desc("The prefix to use for this group of passes. The " + "form will be mlirCreate<prefix><passname>, the " + "prefix can avoid conflicts across libraries."), + llvm::cl::cat(passGenCat)); + +const char *const passDecl = R"( +/* Create {0} Pass.
*/ +MlirPass mlirCreate{0}{1}(void); + +)"; + +const char *const fileHeader = R"( +/* Autogenerated by mlir-tblgen; don't manually edit. */ + +#include "mlir-c/Pass.h" + +#ifdef __cplusplus +extern "C" { +#endif + +)"; + +const char *const fileFooter = R"( + +#ifdef __cplusplus +} +#endif +)"; + +/// Emit the C API header: one `MlirPass mlirCreate<prefix><PassName>(void)` declaration per PassBase def, wrapped in extern "C". +static bool emitCAPIHeader(const llvm::RecordKeeper &records, raw_ostream &os) { + os << fileHeader; + for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { + Pass pass(def); + StringRef defName = pass.getDef()->getName(); + os << llvm::formatv(passDecl, groupName, defName); + } + os << fileFooter; + return false; +} + +const char *const passCreateDef = R"( +MlirPass mlirCreate{0}{1}(void) { + return wrap({2}.release()); +} + +)"; + +static bool emitCAPIImpl(const llvm::RecordKeeper &records, raw_ostream &os) { + os << "/* Autogenerated by mlir-tblgen; don't manually edit. */"; + for (const auto *def : records.getAllDerivedDefinitions("PassBase")) { + Pass pass(def); + StringRef defName = pass.getDef()->getName(); + os << llvm::formatv(passCreateDef, groupName, defName, + pass.getConstructor()); + } + return false; +} + +static mlir::GenRegistration genCAPIHeader("gen-pass-capi-header", + "Generate pass C API header", + &emitCAPIHeader); + +static mlir::GenRegistration genCAPIImpl("gen-pass-capi-impl", + "Generate pass C API implementation", + &emitCAPIImpl);