diff --git a/mlir/include/mlir-c/PatternMatch.h b/mlir/include/mlir-c/PatternMatch.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir-c/PatternMatch.h
@@ -0,0 +1,79 @@
+//===-- mlir-c/PatternMatch.h - C API to Pattern Matching ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C interface to the MLIR pattern matcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_PATTERN_MATCH_H
+#define MLIR_C_PATTERN_MATCH_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+#undef DEFINE_C_API_STRUCT
+
+/// Creates a PDLPatternModule from the given module. The returned handle is
+/// owned by the caller and is released by handing it to
+/// mlirPatternSetAddOwnedPDLPattern.
+MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternGet(MlirModule module);
+
+/// Creates a new RewritePatternSet in the given context. The returned handle
+/// is owned by the caller and is released by handing it to one of the
+/// mlirApplyOwnedPatternsGreedily* functions.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetGet(MlirContext context);
+
+/// Inserts the given PDL pattern into the given pattern set and returns the
+/// resulting pattern set. Takes ownership of `pdlPattern`: the handle is
+/// consumed and must not be used after this call.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirPatternSetAddOwnedPDLPattern(
+    MlirRewritePatternSet patterns, MlirPDLPatternModule pdlPattern);
+
+/// Greedily applies the given pattern set on the given region. Takes ownership
+/// of `patterns`: the handle is consumed and must not be used after this call.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyOwnedPatternsGreedilyOnRegion(
+    MlirRegion region, MlirRewritePatternSet patterns);
+
+/// Greedily applies the given pattern set on the given operation. Takes
+/// ownership of `patterns`: the handle is consumed and must not be used after
+/// this call.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyOwnedPatternsGreedilyOnOperation(
+    MlirOperation op, MlirRewritePatternSet patterns);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_PATTERN_MATCH_H
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -9,6 +9,7 @@
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
+  PatternMatch.cpp
   Support.cpp
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/CAPI/IR/PatternMatch.cpp b/mlir/lib/CAPI/IR/PatternMatch.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/IR/PatternMatch.cpp
@@ -0,0 +1,65 @@
+//===- PatternMatch.cpp - C Interface MLIR Pattern Matcher APIs -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/PatternMatch.h"
+
+#include "mlir/CAPI/Pass.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <memory>
+#include <utility>
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, PDLPatternModule)
+DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, RewritePatternSet)
+
+MlirPDLPatternModule mlirPDLPatternGet(MlirModule module) {
+  // NOTE(review): the module handle is forwarded straight into the
+  // PDLPatternModule constructor; confirm whether that constructor takes
+  // ownership of `module` and document the answer in the C header.
+  return wrap(new PDLPatternModule(unwrap(module)));
+}
+
+MlirRewritePatternSet mlirRewritePatternSetGet(MlirContext context) {
+  return wrap(new RewritePatternSet(unwrap(context)));
+}
+
+MlirRewritePatternSet
+mlirPatternSetAddOwnedPDLPattern(MlirRewritePatternSet patterns,
+                                 MlirPDLPatternModule pdlPattern) {
+  // The caller hands over the heap object created by mlirPDLPatternGet. Moving
+  // out of it only transfers its contents, so reclaim the allocation here to
+  // have it deleted once the patterns have been added.
+  std::unique_ptr<PDLPatternModule> ownedPattern(unwrap(pdlPattern));
+  return wrap(&unwrap(patterns)->add(std::move(*ownedPattern)));
+}
+
+MlirLogicalResult
+mlirApplyOwnedPatternsGreedilyOnRegion(MlirRegion region,
+                                       MlirRewritePatternSet patterns) {
+  // Reclaim the pattern set handed over by the caller so that the allocation
+  // made by mlirRewritePatternSetGet is released when this function returns.
+  std::unique_ptr<RewritePatternSet> ownedPatterns(unwrap(patterns));
+  return wrap(applyPatternsAndFoldGreedily(*unwrap(region),
+                                           std::move(*ownedPatterns)));
+}
+
+MlirLogicalResult
+mlirApplyOwnedPatternsGreedilyOnOperation(MlirOperation op,
+                                          MlirRewritePatternSet patterns) {
+  // Reclaim the pattern set handed over by the caller so that the allocation
+  // made by mlirRewritePatternSetGet is released when this function returns.
+  std::unique_ptr<RewritePatternSet> ownedPatterns(unwrap(patterns));
+  return wrap(
+      applyPatternsAndFoldGreedily(unwrap(op), std::move(*ownedPatterns)));
+}