diff --git a/mlir/include/mlir-c/IntegerSet.h b/mlir/include/mlir-c/IntegerSet.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir-c/IntegerSet.h
@@ -0,0 +1,131 @@
+//===-- mlir-c/IntegerSet.h - C API for MLIR Integer Sets ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_INTEGERSET_H
+#define MLIR_C_INTEGERSET_H
+
+#include "mlir-c/AffineExpr.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// Opaque type declarations.
+//
+// Types are exposed to C bindings as structs containing opaque pointers. They
+// are not supposed to be inspected from C. This allows the underlying
+// representation to change without affecting the API users. The use of structs
+// instead of typedefs enables some type safety as structs are not implicitly
+// convertible to each other.
+//
+// Instances of these types may or may not own the underlying object. The
+// ownership semantics is defined by how an instance of the type was obtained.
+//===----------------------------------------------------------------------===//
+
+#define DEFINE_C_API_STRUCT(name, storage)                                     \
+  struct name {                                                                \
+    storage *ptr;                                                              \
+  };                                                                           \
+  typedef struct name name
+
+DEFINE_C_API_STRUCT(MlirIntegerSet, const void);
+
+#undef DEFINE_C_API_STRUCT
+
+/// Gets the context in which the given integer set lives.
+MLIR_CAPI_EXPORTED MlirContext mlirIntegerSetGetContext(MlirIntegerSet set);
+
+/// Checks whether an integer set is a null object.
+static inline bool mlirIntegerSetIsNull(MlirIntegerSet set) { return !set.ptr; }
+
+/// Checks if two integer set objects are equal. This is a "shallow" comparison
+/// of two objects. Only the sets with some small number of constraints are
+/// uniqued and compare equal here. Set objects that represent the same integer
+/// set with different constraints may be considered non-equal by this check.
+/// Set difference followed by an (expensive) emptiness check should be used to
+/// check equivalence of the underlying integer sets.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetEqual(MlirIntegerSet s1,
+                                            MlirIntegerSet s2);
+
+/// Prints an integer set by sending chunks of the string representation and
+/// forwarding `userData` to `callback`. Note that the callback may be called
+/// several times with consecutive chunks of the string.
+MLIR_CAPI_EXPORTED void mlirIntegerSetPrint(MlirIntegerSet set,
+                                            MlirStringCallback callback,
+                                            void *userData);
+
+/// Prints an integer set to the standard error stream.
+MLIR_CAPI_EXPORTED void mlirIntegerSetDump(MlirIntegerSet set);
+
+/// Gets or creates a new canonically empty integer set with the given number
+/// of dimensions and symbols in the given context.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context,
+                                                         intptr_t numDims,
+                                                         intptr_t numSymbols);
+
+/// Gets or creates a new integer set in the given context. The set is defined
+/// by a list of affine constraints, with the given number of input dimensions
+/// and symbols, which are treated as either equalities (eqFlags is 1) or
+/// inequalities (eqFlags is 0). Both `constraints` and `eqFlags` are expected
+/// to point to at least `numConstraints` consecutive values.
+MLIR_CAPI_EXPORTED MlirIntegerSet
+mlirIntegerSetGet(MlirContext context, intptr_t numDims, intptr_t numSymbols,
+                  intptr_t numConstraints, const MlirAffineExpr *constraints,
+                  const bool *eqFlags);
+
+/// Gets or creates a new integer set in which the dimensions and symbols of the
+/// given set are replaced with the given affine expressions. `dimReplacements`
+/// and `symbolReplacements` are expected to point to at least as many
+/// consecutive expressions as the given set has dimensions and symbols,
+/// respectively. The new set will have `numResultDims` and `numResultSymbols`
+/// dimensions and symbols, respectively.
+MLIR_CAPI_EXPORTED MlirIntegerSet mlirIntegerSetReplaceGet(
+    MlirIntegerSet set, const MlirAffineExpr *dimReplacements,
+    const MlirAffineExpr *symbolReplacements, intptr_t numResultDims,
+    intptr_t numResultSymbols);
+
+/// Checks whether the given set is a canonical empty set, i.e., the set
+/// returned by mlirIntegerSetEmptyGet.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set);
+
+/// Returns the number of dimensions in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set);
+
+/// Returns the number of symbols in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set);
+
+/// Returns the number of inputs (dimensions + symbols) in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set);
+
+/// Returns the number of constraints (equalities + inequalities) in the given
+/// set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set);
+
+/// Returns the number of equalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set);
+
+/// Returns the number of inequalities in the given set.
+MLIR_CAPI_EXPORTED intptr_t
+mlirIntegerSetGetNumInequalities(MlirIntegerSet set);
+
+/// Returns `pos`-th constraint of the set.
+MLIR_CAPI_EXPORTED MlirAffineExpr
+mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos);
+
+/// Returns `true` if the `pos`-th constraint of the set is an equality
+/// constraint, `false` otherwise.
+MLIR_CAPI_EXPORTED bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set,
+                                                     intptr_t pos);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_INTEGERSET_H
diff --git a/mlir/include/mlir/CAPI/IntegerSet.h b/mlir/include/mlir/CAPI/IntegerSet.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/CAPI/IntegerSet.h
@@ -0,0 +1,24 @@
+//===- IntegerSet.h - C API Utils for Integer Sets --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR IntegerSets. This file should not be included from C++ code other than C
+// API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_INTEGERSET_H
+#define MLIR_CAPI_INTEGERSET_H
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/IntegerSet.h"
+
+DEFINE_C_API_METHODS(MlirIntegerSet, mlir::IntegerSet);
+
+#endif // MLIR_CAPI_INTEGERSET_H
diff --git a/mlir/include/mlir/IR/IntegerSet.h b/mlir/include/mlir/IR/IntegerSet.h
--- a/mlir/include/mlir/IR/IntegerSet.h
+++ b/mlir/include/mlir/IR/IntegerSet.h
@@ -104,6 +104,15 @@
 
   friend ::llvm::hash_code hash_value(IntegerSet arg);
 
+  /// Methods supporting C API.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(set);
+  }
+  static IntegerSet getFromOpaquePointer(const void *pointer) {
+    return IntegerSet(
+        reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
+  }
+
 private:
   ImplType *set;
   /// Sets with constraints fewer than kUniquingThreshold are uniqued.
diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt
--- a/mlir/lib/CAPI/IR/CMakeLists.txt
+++ b/mlir/lib/CAPI/IR/CMakeLists.txt
@@ -5,6 +5,7 @@
   BuiltinAttributes.cpp
   BuiltinTypes.cpp
   Diagnostics.cpp
+  IntegerSet.cpp
   IR.cpp
   Pass.cpp
   Support.cpp
diff --git a/mlir/lib/CAPI/IR/IntegerSet.cpp b/mlir/lib/CAPI/IR/IntegerSet.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/IR/IntegerSet.cpp
@@ -0,0 +1,103 @@
+//===- IntegerSet.cpp - C API for MLIR Integer Sets -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/IntegerSet.h"
+#include "mlir-c/AffineExpr.h"
+#include "mlir/CAPI/AffineExpr.h"
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/IntegerSet.h"
+#include "mlir/CAPI/Utils.h"
+#include "mlir/IR/IntegerSet.h"
+
+using namespace mlir;
+
+MlirContext mlirIntegerSetGetContext(MlirIntegerSet set) {
+  return wrap(unwrap(set).getContext());
+}
+
+bool mlirIntegerSetEqual(MlirIntegerSet s1, MlirIntegerSet s2) {
+  return unwrap(s1) == unwrap(s2);
+}
+
+void mlirIntegerSetPrint(MlirIntegerSet set, MlirStringCallback callback,
+                         void *userData) {
+  mlir::detail::CallbackOstream stream(callback, userData);
+  unwrap(set).print(stream);
+}
+
+void mlirIntegerSetDump(MlirIntegerSet set) { unwrap(set).dump(); }
+
+MlirIntegerSet mlirIntegerSetEmptyGet(MlirContext context, intptr_t numDims,
+                                      intptr_t numSymbols) {
+  return wrap(IntegerSet::getEmptySet(static_cast<unsigned>(numDims),
+                                      static_cast<unsigned>(numSymbols),
+                                      unwrap(context)));
+}
+
+MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims,
+                                 intptr_t numSymbols, intptr_t numConstraints,
+                                 const MlirAffineExpr *constraints,
+                                 const bool *eqFlags) {
+  SmallVector<AffineExpr, 4> mlirConstraints;
+  (void)unwrapList(static_cast<size_t>(numConstraints), constraints,
+                   mlirConstraints);
+  return wrap(IntegerSet::get(
+      static_cast<unsigned>(numDims), static_cast<unsigned>(numSymbols),
+      mlirConstraints,
+      llvm::makeArrayRef(eqFlags, static_cast<size_t>(numConstraints))));
+}
+
+MlirIntegerSet
+mlirIntegerSetReplaceGet(MlirIntegerSet set,
+                         const MlirAffineExpr *dimReplacements,
+                         const MlirAffineExpr *symbolReplacements,
+                         intptr_t numResultDims, intptr_t numResultSymbols) {
+  SmallVector<AffineExpr, 4> mlirDims, mlirSymbols;
+  (void)unwrapList(unwrap(set).getNumDims(), dimReplacements, mlirDims);
+  (void)unwrapList(unwrap(set).getNumSymbols(), symbolReplacements,
+                   mlirSymbols);
+  return wrap(unwrap(set).replaceDimsAndSymbols(
+      mlirDims, mlirSymbols, static_cast<unsigned>(numResultDims),
+      static_cast<unsigned>(numResultSymbols)));
+}
+
+bool mlirIntegerSetIsCanonicalEmpty(MlirIntegerSet set) {
+  return unwrap(set).isEmptyIntegerSet();
+}
+
+intptr_t mlirIntegerSetGetNumDims(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumDims());
+}
+
+intptr_t mlirIntegerSetGetNumSymbols(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumSymbols());
+}
+
+intptr_t mlirIntegerSetGetNumInputs(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumInputs());
+}
+
+intptr_t mlirIntegerSetGetNumConstraints(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumConstraints());
+}
+
+intptr_t mlirIntegerSetGetNumEqualities(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumEqualities());
+}
+
+intptr_t mlirIntegerSetGetNumInequalities(MlirIntegerSet set) {
+  return static_cast<intptr_t>(unwrap(set).getNumInequalities());
+}
+
+MlirAffineExpr mlirIntegerSetGetConstraint(MlirIntegerSet set, intptr_t pos) {
+  return wrap(unwrap(set).getConstraint(static_cast<unsigned>(pos)));
+}
+
+bool mlirIntegerSetIsConstraintEq(MlirIntegerSet set, intptr_t pos) {
+  return unwrap(set).isEq(static_cast<unsigned>(pos));
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -17,6 +17,7 @@
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
 #include "mlir-c/Dialect/Standard.h"
+#include "mlir-c/IntegerSet.h"
 #include "mlir-c/Registration.h"
 
 #include <assert.h>
@@ -1325,6 +1326,85 @@
   return 0;
 }
 
+int printIntegerSet(MlirContext ctx) {
+  MlirIntegerSet emptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+
+  // CHECK-LABEL: @printIntegerSet
+  fprintf(stderr, "@printIntegerSet");
+
+  // CHECK: (d0, d1)[s0] : (1 == 0)
+  mlirIntegerSetDump(emptySet);
+
+  if (!mlirIntegerSetIsCanonicalEmpty(emptySet))
+    return 1;
+
+  MlirIntegerSet anotherEmptySet = mlirIntegerSetEmptyGet(ctx, 2, 1);
+  if (!mlirIntegerSetEqual(emptySet, anotherEmptySet))
+    return 2;
+
+  // Construct a set constrained by:
+  //   d0 - s0 == 0,
+  //   d1 - 42 >= 0.
+  MlirAffineExpr negOne = mlirAffineConstantExprGet(ctx, -1);
+  MlirAffineExpr negFortyTwo = mlirAffineConstantExprGet(ctx, -42);
+  MlirAffineExpr d0 = mlirAffineDimExprGet(ctx, 0);
+  MlirAffineExpr d1 = mlirAffineDimExprGet(ctx, 1);
+  MlirAffineExpr s0 = mlirAffineSymbolExprGet(ctx, 0);
+  MlirAffineExpr negS0 = mlirAffineMulExprGet(negOne, s0);
+  MlirAffineExpr d0minusS0 = mlirAffineAddExprGet(d0, negS0);
+  MlirAffineExpr d1minus42 = mlirAffineAddExprGet(d1, negFortyTwo);
+  MlirAffineExpr constraints[] = {d0minusS0, d1minus42};
+  bool flags[] = {true, false};
+
+  MlirIntegerSet set = mlirIntegerSetGet(ctx, 2, 1, 2, constraints, flags);
+  // CHECK: (d0, d1)[s0] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: d1 - 42 >= 0
+  mlirIntegerSetDump(set);
+
+  // Keep d0 and s0 as they are and transform d1 into the new symbol s1.
+  MlirAffineExpr s1 = mlirAffineSymbolExprGet(ctx, 1);
+  MlirAffineExpr repl[] = {d0, s1};
+  MlirIntegerSet replaced = mlirIntegerSetReplaceGet(set, repl, &s0, 1, 2);
+  // CHECK: (d0)[s0, s1] : (
+  // CHECK-DAG: d0 - s0 == 0
+  // CHECK-DAG: s1 - 42 >= 0
+  mlirIntegerSetDump(replaced);
+
+  if (mlirIntegerSetGetNumDims(set) != 2)
+    return 3;
+  if (mlirIntegerSetGetNumDims(replaced) != 1)
+    return 4;
+
+  if (mlirIntegerSetGetNumSymbols(set) != 1)
+    return 5;
+  if (mlirIntegerSetGetNumSymbols(replaced) != 2)
+    return 6;
+
+  if (mlirIntegerSetGetNumInputs(set) != 3)
+    return 7;
+
+  if (mlirIntegerSetGetNumConstraints(set) != 2)
+    return 8;
+
+  if (mlirIntegerSetGetNumEqualities(set) != 1)
+    return 9;
+
+  if (mlirIntegerSetGetNumInequalities(set) != 1)
+    return 10;
+
+  MlirAffineExpr cstr1 = mlirIntegerSetGetConstraint(set, 0);
+  MlirAffineExpr cstr2 = mlirIntegerSetGetConstraint(set, 1);
+  bool isEq1 = mlirIntegerSetIsConstraintEq(set, 0);
+  bool isEq2 = mlirIntegerSetIsConstraintEq(set, 1);
+  if (!mlirAffineExprEqual(cstr1, isEq1 ? d0minusS0 : d1minus42))
+    return 11;
+  if (!mlirAffineExprEqual(cstr2, isEq2 ? d0minusS0 : d1minus42))
+    return 12;
+
+  return 0;
+}
+
 int registerOnlyStd() {
   MlirContext ctx = mlirContextCreate();
   // The built-in dialect is always loaded.
@@ -1429,8 +1509,10 @@
     return 6;
   if (affineMapFromExprs(ctx))
     return 7;
-  if (registerOnlyStd())
+  if (printIntegerSet(ctx))
     return 8;
+  if (registerOnlyStd())
+    return 9;
 
   mlirContextDestroy(ctx);
 