diff --git a/mlir/include/mlir-c/PatternMatch.h b/mlir/include/mlir-c/PatternMatch.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir-c/PatternMatch.h
@@ -0,0 +1,88 @@
+//===-- mlir-c/PatternMatch.h - C API to Pattern Matching ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C interface to the MLIR pattern matcher.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_PATTERN_MATCH_H
+#define MLIR_C_PATTERN_MATCH_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage) \
+  struct name {                            \
+    storage *ptr;                          \
+  };                                       \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirPDLPatternModule, void);
+DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
+
+#undef DEFINE_C_API_STRUCT
+
+/// Creates a PDLPatternModule from the given Module, taking ownership of
+/// `module`. The result is owned by the caller until it is handed to
+/// mlirPatternSetAddOwnedPDLPattern or released with
+/// mlirPDLPatternModuleDestroy.
+MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternGet(MlirModule module);
+
+/// Destroys a PDLPatternModule that is still owned by the caller.
+MLIR_CAPI_EXPORTED void
+mlirPDLPatternModuleDestroy(MlirPDLPatternModule pdlPattern);
+
+/// Creates a new, empty RewritePatternSet owned by the caller. It must either
+/// be consumed by one of the mlirApplyOwnedPatternsGreedilyOn* functions or be
+/// released with mlirRewritePatternSetDestroy.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet
+mlirRewritePatternSetGet(MlirContext context);
+
+/// Destroys a RewritePatternSet that is still owned by the caller.
+MLIR_CAPI_EXPORTED void
+mlirRewritePatternSetDestroy(MlirRewritePatternSet patterns);
+
+/// Takes a PDL pattern owned by the caller and inserts it in the given pattern
+/// set. `pdlPattern` is consumed and must not be used afterwards. Returns the
+/// updated pattern set, i.e. the same handle as `patterns`.
+MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirPatternSetAddOwnedPDLPattern(
+    MlirRewritePatternSet patterns, MlirPDLPatternModule pdlPattern);
+
+/// Takes a pattern set owned by the caller and greedily applies it on the given
+/// region. `patterns` is consumed (and freed) and must not be used afterwards.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyOwnedPatternsGreedilyOnRegion(
+    MlirRegion region, MlirRewritePatternSet patterns);
+
+/// Takes a pattern set owned by the caller and greedily applies it on the given
+/// operation. `patterns` is consumed (and freed) and must not be used afterwards.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyOwnedPatternsGreedilyOnOperation(
+    MlirOperation op, MlirRewritePatternSet patterns);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_PATTERN_MATCH_H
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -9,6 +9,7 @@
   IntegerSet.cpp
   IR.cpp
   Pass.cpp
+  PatternMatch.cpp
   Support.cpp
 
   LINK_LIBS PUBLIC
diff --git a/mlir/lib/CAPI/IR/PatternMatch.cpp b/mlir/lib/CAPI/IR/PatternMatch.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/IR/PatternMatch.cpp
@@ -0,0 +1,76 @@
+//===- PatternMatch.cpp - C Interface MLIR Pattern Matcher APIs -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/PatternMatch.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Pass.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <memory>
+
+using namespace mlir;
+
+DEFINE_C_API_PTR_METHODS(MlirPDLPatternModule, PDLPatternModule)
+DEFINE_C_API_PTR_METHODS(MlirRewritePatternSet, RewritePatternSet)
+
+/// Adopts the caller-owned pattern set behind `patterns`, freezes its contents
+/// and releases the heap-allocated RewritePatternSet before returning.
+static FrozenRewritePatternSet
+freezeOwnedPatterns(MlirRewritePatternSet patterns) {
+  std::unique_ptr<RewritePatternSet> owned(unwrap(patterns));
+  return FrozenRewritePatternSet(std::move(*owned));
+}
+
+MlirPDLPatternModule mlirPDLPatternGet(MlirModule module) {
+  // The PDLPatternModule takes ownership of the module operation.
+  return wrap(new PDLPatternModule(OwningOpRef<ModuleOp>(unwrap(module))));
+}
+
+void mlirPDLPatternModuleDestroy(MlirPDLPatternModule pdlPattern) {
+  delete unwrap(pdlPattern);
+}
+
+MlirRewritePatternSet mlirRewritePatternSetGet(MlirContext context) {
+  return wrap(new RewritePatternSet(unwrap(context)));
+}
+
+void mlirRewritePatternSetDestroy(MlirRewritePatternSet patterns) {
+  delete unwrap(patterns);
+}
+
+MlirRewritePatternSet
+mlirPatternSetAddOwnedPDLPattern(MlirRewritePatternSet patterns,
+                                 MlirPDLPatternModule pdlPattern) {
+  // `pdlPattern` is consumed: its contents are moved into the set and the
+  // emptied heap object is released when `owned` goes out of scope.
+  std::unique_ptr<PDLPatternModule> owned(unwrap(pdlPattern));
+  RewritePatternSet *set = unwrap(patterns);
+  set->add(std::move(*owned));
+  return wrap(set);
+}
+
+MlirLogicalResult
+mlirApplyOwnedPatternsGreedilyOnRegion(MlirRegion region,
+                                       MlirRewritePatternSet patterns) {
+  FrozenRewritePatternSet frozen = freezeOwnedPatterns(patterns);
+  return wrap(applyPatternsAndFoldGreedily(*unwrap(region), frozen));
+}
+
+MlirLogicalResult
+mlirApplyOwnedPatternsGreedilyOnOperation(MlirOperation op,
+                                          MlirRewritePatternSet patterns) {
+  FrozenRewritePatternSet frozen = freezeOwnedPatterns(patterns);
+  return wrap(applyPatternsAndFoldGreedily(unwrap(op), frozen));
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -54,6 +54,15 @@
     MLIRCAPITransforms
 )
 
+_add_capi_test_executable(mlir-capi-pattern-match-test
+  pattern_match.c
+  LINK_LIBS PRIVATE
+    MLIRCAPIFunc
+    MLIRCAPIIR
+    MLIRCAPIRegisterEverything
+    MLIRCAPITransforms
+)
+
 _add_capi_test_executable(mlir-capi-pdl-test
   pdl.c
   LINK_LIBS PRIVATE
diff --git a/mlir/test/CAPI/pattern_match.c b/mlir/test/CAPI/pattern_match.c
new file mode 100644
--- /dev/null
+++ b/mlir/test/CAPI/pattern_match.c
@@ -0,0 +1,159 @@
+//===- pattern_match.c - Simple test of C APIs ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+/* RUN: mlir-capi-pattern-match-test 2>&1 | FileCheck %s
+ */
+
+#include "mlir-c/IR.h"
+#include "mlir-c/PatternMatch.h"
+#include "mlir-c/RegisterEverything.h"
+#include "mlir-c/Support.h"
+
+#include <stdio.h>
+
+static void registerAllUpstreamDialects(MlirContext ctx) {
+  MlirDialectRegistry registry = mlirDialectRegistryCreate();
+  mlirRegisterAllDialects(registry);
+  mlirContextAppendDialectRegistry(ctx, registry);
+  mlirDialectRegistryDestroy(registry);
+}
+
+/// Builds a module holding a single function:
+///   func.func @add(%arg0: i32, %arg1: i32) -> i32 {
+///     %0 = arith.addi %arg0, %arg1 : i32
+///     return %0 : i32
+///   }
+/// The returned module is owned by the caller.
+static MlirModule createModule(MlirContext ctx) {
+  MlirLocation location = mlirLocationUnknownGet(ctx);
+  MlirModule moduleOp = mlirModuleCreateEmpty(location);
+  MlirBlock moduleBody = mlirModuleGetBody(moduleOp);
+
+  MlirRegion funcBodyRegion = mlirRegionCreate();
+  MlirType intType =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("i32"));
+  MlirType funcBodyArgTypes[] = {intType, intType};
+  MlirLocation funcBodyArgLocs[] = {location, location};
+  MlirBlock funcBody =
+      mlirBlockCreate(sizeof(funcBodyArgTypes) / sizeof(MlirType),
+                      funcBodyArgTypes, funcBodyArgLocs);
+  mlirRegionAppendOwnedBlock(funcBodyRegion, funcBody);
+  MlirAttribute funcTypeAttr = mlirAttributeParseGet(
+      ctx, mlirStringRefCreateFromCString("(i32, i32) -> (i32)"));
+  MlirAttribute funcNameAttr =
+      mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("\"add\""));
+  MlirNamedAttribute funcAttrs[] = {
+      mlirNamedAttributeGet(
+          mlirIdentifierGet(ctx,
+                            mlirStringRefCreateFromCString("function_type")),
+          funcTypeAttr),
+      mlirNamedAttributeGet(
+          mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("sym_name")),
+          funcNameAttr)};
+  MlirOperationState funcState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("func.func"), location);
+  mlirOperationStateAddAttributes(&funcState, 2, funcAttrs);
+  mlirOperationStateAddOwnedRegions(&funcState, 1, &funcBodyRegion);
+  MlirOperation func = mlirOperationCreate(&funcState);
+  mlirBlockInsertOwnedOperation(moduleBody, 0, func);
+
+  MlirOperationState addState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("arith.addi"), location);
+  mlirOperationStateAddResults(&addState, 1, &intType);
+  MlirValue operands[] = {mlirBlockGetArgument(funcBody, 0),
+                          mlirBlockGetArgument(funcBody, 1)};
+  mlirOperationStateAddOperands(&addState, 2, operands);
+  MlirOperation add = mlirOperationCreate(&addState);
+  mlirBlockAppendOwnedOperation(funcBody, add);
+
+  MlirOperationState returnState = mlirOperationStateGet(
+      mlirStringRefCreateFromCString("func.return"), location);
+  MlirValue operand = mlirOperationGetResult(add, 0);
+  mlirOperationStateAddOperands(&returnState, 1, &operand);
+  MlirOperation ret = mlirOperationCreate(&returnState);
+  mlirBlockAppendOwnedOperation(funcBody, ret);
+
+  return moduleOp;
+}
+
+/// Rewrites the `arith.addi` produced by createModule into an `arith.subi`
+/// through a PDL pattern and prints the module before and after. Returns 0
+/// on success and 1 on failure.
+static int testPDLPattern(MlirContext ctx) {
+  MlirModule moduleOp = createModule(ctx);
+
+  mlirOperationDump(mlirModuleGetOperation(moduleOp));
+  // clang-format off
+  // CHECK: module {
+  // CHECK:   func.func @add(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK:     %0 = arith.addi %arg0, %arg1 : i32
+  // CHECK:     return %0 : i32
+  // CHECK:   }
+  // CHECK: }
+  // clang-format on
+
+  MlirModule pdlModuleOp = mlirModuleCreateParse(
+      ctx, mlirStringRefCreateFromCString("\
+      module @test {\
+        pdl.pattern : benefit(0) {\
+          %0 = operand\
+          %1 = operand\
+          %2 = types\
+          %3 = operation \"arith.addi\"(%0, %1 : !pdl.value, !pdl.value) -> (%2 : !pdl.range<type>)\
+          rewrite %3 {\
+            %4 = operation \"arith.subi\"(%0, %1 : !pdl.value, !pdl.value)\
+            replace %3 with %4\
+          }\
+        }\
+      }\
+  "));
+  if (mlirModuleIsNull(pdlModuleOp)) {
+    fprintf(stderr, "error: failed to parse the PDL pattern module\n");
+    mlirModuleDestroy(moduleOp);
+    return 1;
+  }
+
+  // Ownership flows PDL module -> PDL pattern -> pattern set; the set itself
+  // is consumed by the apply call, so none of these handles is freed here.
+  MlirPDLPatternModule pdlModule = mlirPDLPatternGet(pdlModuleOp);
+  MlirRewritePatternSet set = mlirRewritePatternSetGet(ctx);
+  mlirPatternSetAddOwnedPDLPattern(set, pdlModule);
+  MlirLogicalResult result = mlirApplyOwnedPatternsGreedilyOnOperation(
+      mlirModuleGetOperation(moduleOp), set);
+  if (mlirLogicalResultIsFailure(result)) {
+    fprintf(stderr, "error: failed to apply the PDL pattern\n");
+    mlirModuleDestroy(moduleOp);
+    return 1;
+  }
+
+  mlirOperationDump(mlirModuleGetOperation(moduleOp));
+  // clang-format off
+  // CHECK: module {
+  // CHECK:   func.func @add(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK:     %0 = arith.subi %arg0, %arg1 : i32
+  // CHECK:     return %0 : i32
+  // CHECK:   }
+  // CHECK: }
+  // clang-format on
+
+  mlirModuleDestroy(moduleOp);
+  return 0;
+}
+
+int main(void) {
+  MlirContext ctx = mlirContextCreate();
+  registerAllUpstreamDialects(ctx);
+  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
+  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
+  mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("pdl"));
+
+  int status = testPDLPattern(ctx);
+  mlirContextDestroy(ctx);
+  return status;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -91,6 +91,7 @@
   mlir-capi-ir-test
   mlir-capi-llvm-test
   mlir-capi-pass-test
+  mlir-capi-pattern-match-test
   mlir-capi-pdl-test
   mlir-capi-quant-test
   mlir-capi-sparse-tensor-test
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -90,6 +90,7 @@
     "mlir-capi-ir-test",
     "mlir-capi-llvm-test",
     "mlir-capi-pass-test",
+    "mlir-capi-pattern-match-test",
     "mlir-capi-pdl-test",
     "mlir-capi-quant-test",
     "mlir-capi-sparse-tensor-test",