diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -0,0 +1,199 @@
+//===-- mlir-c/Dialect/Quant.h - C API for Quant dialect ----------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_QUANT_H
+#define MLIR_C_DIALECT_QUANT_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Registration.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(quant, quant);
+
+//===---------------------------------------------------------------------===//
+// QuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a quantization dialect type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAQuantizedType(MlirType type);
+
+/// Returns the bit flag used to indicate signedness of a quantized type.
+MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetSignedFlag(void);
+
+/// Returns the minimum possible value stored by a quantized type.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(
+    bool isSigned, unsigned integralWidth);
+
+/// Returns the maximum possible value stored by a quantized type.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(
+    bool isSigned, unsigned integralWidth);
+
+/// Gets the original type approximated by the given quantized type.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetExpressedType(MlirType type);
+
+/// Gets the flags associated with the given quantized type.
+MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetFlags(MlirType type);
+
+/// Returns `true` if the given type is signed, `false` otherwise.
+MLIR_CAPI_EXPORTED bool mlirQuantizedTypeIsSigned(MlirType type);
+
+/// Returns the underlying type used to store the values.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetStorageType(MlirType type);
+
+/// Returns the minimum value that the storage type of the given quantized type
+/// can take.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type);
+
+/// Returns the maximum value that the storage type of the given quantized type
+/// can take.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type);
+
+/// Returns the integral bitwidth that the storage type of the given quantized
+/// type can represent exactly.
+MLIR_CAPI_EXPORTED unsigned
+mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type);
+
+/// Returns `true` if the `candidate` type is compatible with the given
+/// quantized `type`.
+MLIR_CAPI_EXPORTED bool
+mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, MlirType candidate);
+
+/// Returns the element type of the given quantized type as another quantized
+/// type.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeGetQuantizedElementType(MlirType type);
+
+/// Casts from a type based on the storage type of the given type to a
+/// corresponding type based on the given type. Returns a null type if the cast
+/// is not valid.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastFromStorageType(MlirType type, MlirType candidate);
+
+/// Casts from a type based on a quantized type to a corresponding type based
+/// on the storage type. Returns a null type if the cast is not valid.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToStorageType(MlirType type);
+
+/// Casts from a type based on the expressed type of the given type to a
+/// corresponding type based on the given type. Returns a null type if the cast
+/// is not valid.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastFromExpressedType(MlirType type, MlirType candidate);
+
+/// Casts from a type based on a quantized type to a corresponding type based
+/// on the expressed type. Returns a null type if the cast is not valid.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToExpressedType(MlirType type);
+
+/// Casts from a type based on the expressed type of the given quantized type to
+/// an equivalent type based on the storage type of the same quantized type.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
+
+//===---------------------------------------------------------------------===//
+// AnyQuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is an AnyQuantizedType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
+
+/// Creates an instance of AnyQuantizedType with the given parameters in the
+/// same context as `storageType` and returns it. The instance is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags,
+                                                    MlirType storageType,
+                                                    MlirType expressedType,
+                                                    int64_t storageTypeMin,
+                                                    int64_t storageTypeMax);
+
+//===---------------------------------------------------------------------===//
+// UniformQuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type);
+
+/// Creates an instance of UniformQuantizedType with the given parameters in the
+/// same context as `storageType` and returns it. The instance is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType, double scale,
+    int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the scale of the given uniform quantized type.
+MLIR_CAPI_EXPORTED double mlirUniformQuantizedTypeGetScale(MlirType type);
+
+/// Returns the zero point of the given uniform quantized type.
+MLIR_CAPI_EXPORTED int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type);
+
+/// Returns `true` if the given uniform quantized type is fixed-point.
+MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type);
+
+//===---------------------------------------------------------------------===//
+// UniformQuantizedPerAxisType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a UniformQuantizedPerAxisType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type);
+
+/// Creates an instance of UniformQuantizedPerAxisType with the given parameters
+/// in the same context as `storageType` and returns it. `scales` and
+/// `zeroPoints` point to `nDims` number of elements. The instance is owned
+/// by the context.
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedPerAxisTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    intptr_t nDims, double *scales, int64_t *zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax);
+
+/// Returns the number of axes in the given quantized per-axis type.
+MLIR_CAPI_EXPORTED intptr_t
+mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type);
+
+/// Returns `pos`-th scale of the given quantized per-axis type.
+MLIR_CAPI_EXPORTED double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type,
+                                                                  intptr_t pos);
+
+/// Returns `pos`-th zero point of the given quantized per-axis type.
+MLIR_CAPI_EXPORTED int64_t
+mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, intptr_t pos);
+
+/// Returns the index of the quantized dimension in the given quantized per-axis
+/// type.
+MLIR_CAPI_EXPORTED int32_t
+mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type);
+
+/// Returns `true` if the given uniform quantized per-axis type is fixed-point.
+MLIR_CAPI_EXPORTED bool
+mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type);
+
+//===---------------------------------------------------------------------===//
+// CalibratedQuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a CalibratedQuantizedType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type);
+
+/// Creates an instance of CalibratedQuantizedType with the given parameters
+/// in the same context as `expressedType` and returns it. The instance is owned
+/// by the context.
+MLIR_CAPI_EXPORTED MlirType
+mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max);
+
+/// Returns the min value of the given calibrated quantized type.
+MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMin(MlirType type);
+
+/// Returns the max value of the given calibrated quantized type.
+MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMax(MlirType type);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_QUANT_H
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt
--- a/mlir/lib/CAPI/Dialect/CMakeLists.txt
+++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt
@@ -97,3 +97,12 @@
   MLIRCAPIIR
   MLIRTensor
 )
+
+add_mlir_upstream_c_api_library(MLIRCAPIQuant
+  Quant.cpp
+
+  PARTIAL_SOURCES_INTENDED
+  LINK_LIBS PUBLIC
+  MLIRCAPIIR
+  MLIRQuant
+)
diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/CAPI/Dialect/Quant.cpp
@@ -0,0 +1,224 @@
+//===- Quant.cpp - C Interface for Quant dialect --------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Quant/QuantOps.h"
+#include "mlir/Dialect/Quant/QuantTypes.h"
+
+using namespace mlir;
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect)
+
+//===---------------------------------------------------------------------===//
+// QuantizedType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAQuantizedType(MlirType type) {
+  return unwrap(type).isa<quant::QuantizedType>();
+}
+
+unsigned mlirQuantizedTypeGetSignedFlag() {
+  return quant::QuantizationFlags::Signed;
+}
+
+int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned,
+                                                     unsigned integralWidth) {
+  return quant::QuantizedType::getDefaultMinimumForInteger(isSigned,
+                                                           integralWidth);
+}
+
+int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned,
+                                                     unsigned integralWidth) {
+  return quant::QuantizedType::getDefaultMaximumForInteger(isSigned,
+                                                           integralWidth);
+}
+
+MlirType mlirQuantizedTypeGetExpressedType(MlirType type) {
+  return wrap(unwrap(type).cast<quant::QuantizedType>().getExpressedType());
+}
+
+unsigned mlirQuantizedTypeGetFlags(MlirType type) {
+  return unwrap(type).cast<quant::QuantizedType>().getFlags();
+}
+
+bool mlirQuantizedTypeIsSigned(MlirType type) {
+  return unwrap(type).cast<quant::QuantizedType>().isSigned();
+}
+
+MlirType mlirQuantizedTypeGetStorageType(MlirType type) {
+  return wrap(unwrap(type).cast<quant::QuantizedType>().getStorageType());
+}
+
+int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) {
+  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMin();
+}
+
+int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) {
+  return unwrap(type).cast<quant::QuantizedType>().getStorageTypeMax();
+}
+
+unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) {
+  return unwrap(type)
+      .cast<quant::QuantizedType>()
+      .getStorageTypeIntegralWidth();
+}
+
+bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type,
+                                                MlirType candidate) {
+  return unwrap(type).cast<quant::QuantizedType>().isCompatibleExpressedType(
+      unwrap(candidate));
+}
+
+MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) {
+  return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type)));
+}
+
+MlirType mlirQuantizedTypeCastFromStorageType(MlirType type,
+                                              MlirType candidate) {
+  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromStorageType(
+      unwrap(candidate)));
+}
+
+MlirType mlirQuantizedTypeCastToStorageType(MlirType type) {
+  return wrap(quant::QuantizedType::castToStorageType(
+      unwrap(type).cast<quant::QuantizedType>()));
+}
+
+MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type,
+                                                MlirType candidate) {
+  return wrap(unwrap(type).cast<quant::QuantizedType>().castFromExpressedType(
+      unwrap(candidate)));
+}
+
+MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) {
+  return wrap(quant::QuantizedType::castToExpressedType(unwrap(type)));
+}
+
+MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type,
+                                                     MlirType candidate) {
+  return wrap(
+      unwrap(type).cast<quant::QuantizedType>().castExpressedToStorageType(
+          unwrap(candidate)));
+}
+
+//===---------------------------------------------------------------------===//
+// AnyQuantizedType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAAnyQuantizedType(MlirType type) {
+  return unwrap(type).isa<quant::AnyQuantizedType>();
+}
+
+MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType,
+                                 MlirType expressedType, int64_t storageTypeMin,
+                                 int64_t storageTypeMax) {
+  return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType),
+                                           unwrap(expressedType),
+                                           storageTypeMin, storageTypeMax));
+}
+
+//===---------------------------------------------------------------------===//
+// UniformQuantizedType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAUniformQuantizedType(MlirType type) {
+  return unwrap(type).isa<quant::UniformQuantizedType>();
+}
+
+MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType,
+                                     MlirType expressedType, double scale,
+                                     int64_t zeroPoint, int64_t storageTypeMin,
+                                     int64_t storageTypeMax) {
+  return wrap(quant::UniformQuantizedType::get(
+      flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint,
+      storageTypeMin, storageTypeMax));
+}
+
+double mlirUniformQuantizedTypeGetScale(MlirType type) {
+  return unwrap(type).cast<quant::UniformQuantizedType>().getScale();
+}
+
+int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) {
+  return unwrap(type).cast<quant::UniformQuantizedType>().getZeroPoint();
+}
+
+bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) {
+  return unwrap(type).cast<quant::UniformQuantizedType>().isFixedPoint();
+}
+
+//===---------------------------------------------------------------------===//
+// UniformQuantizedPerAxisType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) {
+  return unwrap(type).isa<quant::UniformQuantizedPerAxisType>();
+}
+
+MlirType mlirUniformQuantizedPerAxisTypeGet(
+    unsigned flags, MlirType storageType, MlirType expressedType,
+    intptr_t nDims, double *scales, int64_t *zeroPoints,
+    int32_t quantizedDimension, int64_t storageTypeMin,
+    int64_t storageTypeMax) {
+  return wrap(quant::UniformQuantizedPerAxisType::get(
+      flags, unwrap(storageType), unwrap(expressedType),
+      llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims),
+      quantizedDimension, storageTypeMin, storageTypeMax));
+}
+
+intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) {
+  return unwrap(type)
+      .cast<quant::UniformQuantizedPerAxisType>()
+      .getScales()
+      .size();
+}
+
+double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) {
+  return unwrap(type)
+      .cast<quant::UniformQuantizedPerAxisType>()
+      .getScales()[pos];
+}
+
+int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type,
+                                                    intptr_t pos) {
+  return unwrap(type)
+      .cast<quant::UniformQuantizedPerAxisType>()
+      .getZeroPoints()[pos];
+}
+
+int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) {
+  return unwrap(type)
+      .cast<quant::UniformQuantizedPerAxisType>()
+      .getQuantizedDimension();
+}
+
+bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) {
+  return unwrap(type).cast<quant::UniformQuantizedPerAxisType>().isFixedPoint();
+}
+
+//===---------------------------------------------------------------------===//
+// CalibratedQuantizedType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsACalibratedQuantizedType(MlirType type) {
+  return unwrap(type).isa<quant::CalibratedQuantizedType>();
+}
+
+MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min,
+                                        double max) {
+  return wrap(
+      quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max));
+}
+
+double mlirCalibratedQuantizedTypeGetMin(MlirType type) {
+  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMin();
+}
+
+double mlirCalibratedQuantizedTypeGetMax(MlirType type) {
+  return unwrap(type).cast<quant::CalibratedQuantizedType>().getMax();
+}
diff --git a/mlir/test/CAPI/CMakeLists.txt b/mlir/test/CAPI/CMakeLists.txt
--- a/mlir/test/CAPI/CMakeLists.txt
+++ b/mlir/test/CAPI/CMakeLists.txt
@@ -58,3 +58,11 @@
   MLIRCAPIRegistration
   MLIRCAPISparseTensor
 )
+
+_add_capi_test_executable(mlir-capi-quant-test
+  quant.c
+  LINK_LIBS PRIVATE
+  MLIRCAPIIR
+  MLIRCAPIRegistration
+  MLIRCAPIQuant
+)
diff --git a/mlir/test/CAPI/quant.c b/mlir/test/CAPI/quant.c
new file mode 100644
--- /dev/null
+++ b/mlir/test/CAPI/quant.c
@@ -0,0 +1,239 @@
+//===- quant.c - Test of Quant dialect C API ------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+// RUN: mlir-capi-quant-test 2>&1 | FileCheck %s
+
+#include "mlir-c/Dialect/Quant.h"
+#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/IR.h"
+
+#include <assert.h>
+#include <inttypes.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+// CHECK-LABEL: testTypeHierarchy
+static void testTypeHierarchy(MlirContext ctx) {
+  fprintf(stderr, "testTypeHierarchy\n");
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType any = mlirTypeParseGet(
+      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
+  MlirType uniform =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
+                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
+  MlirType perAxis = mlirTypeParseGet(
+      ctx, mlirStringRefCreateFromCString(
+               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
+  MlirType calibrated = mlirTypeParseGet(
+      ctx,
+      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
+
+  // The parser itself is checked in C++ dialect tests.
+  assert(!mlirTypeIsNull(any) && "couldn't parse AnyQuantizedType");
+  assert(!mlirTypeIsNull(uniform) && "couldn't parse UniformQuantizedType");
+  assert(!mlirTypeIsNull(perAxis) &&
+         "couldn't parse UniformQuantizedPerAxisType");
+  assert(!mlirTypeIsNull(calibrated) &&
+         "couldn't parse CalibratedQuantizedType");
+
+  // CHECK: i8 isa QuantizedType: 0
+  fprintf(stderr, "i8 isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(i8));
+  // CHECK: any isa QuantizedType: 1
+  fprintf(stderr, "any isa QuantizedType: %d\n", mlirTypeIsAQuantizedType(any));
+  // CHECK: uniform isa QuantizedType: 1
+  fprintf(stderr, "uniform isa QuantizedType: %d\n",
+          mlirTypeIsAQuantizedType(uniform));
+  // CHECK: perAxis isa QuantizedType: 1
+  fprintf(stderr, "perAxis isa QuantizedType: %d\n",
+          mlirTypeIsAQuantizedType(perAxis));
+  // CHECK: calibrated isa QuantizedType: 1
+  fprintf(stderr, "calibrated isa QuantizedType: %d\n",
+          mlirTypeIsAQuantizedType(calibrated));
+
+  // CHECK: any isa AnyQuantizedType: 1
+  fprintf(stderr, "any isa AnyQuantizedType: %d\n",
+          mlirTypeIsAAnyQuantizedType(any));
+  // CHECK: uniform isa UniformQuantizedType: 1
+  fprintf(stderr, "uniform isa UniformQuantizedType: %d\n",
+          mlirTypeIsAUniformQuantizedType(uniform));
+  // CHECK: perAxis isa UniformQuantizedPerAxisType: 1
+  fprintf(stderr, "perAxis isa UniformQuantizedPerAxisType: %d\n",
+          mlirTypeIsAUniformQuantizedPerAxisType(perAxis));
+  // CHECK: calibrated isa CalibratedQuantizedType: 1
+  fprintf(stderr, "calibrated isa CalibratedQuantizedType: %d\n",
+          mlirTypeIsACalibratedQuantizedType(calibrated));
+
+  // CHECK: perAxis isa UniformQuantizedType: 0
+  fprintf(stderr, "perAxis isa UniformQuantizedType: %d\n",
+          mlirTypeIsAUniformQuantizedType(perAxis));
+  // CHECK: uniform isa CalibratedQuantizedType: 0
+  fprintf(stderr, "uniform isa CalibratedQuantizedType: %d\n",
+          mlirTypeIsACalibratedQuantizedType(uniform));
+  fprintf(stderr, "\n");
+}
+
+// CHECK-LABEL: testAnyQuantizedType
+static void testAnyQuantizedType(MlirContext ctx) {
+  fprintf(stderr, "testAnyQuantizedType\n");
+
+  MlirType anyParsed = mlirTypeParseGet(
+      ctx, mlirStringRefCreateFromCString("!quant.any<i8<-8:7>:f32>"));
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirType any =
+      mlirAnyQuantizedTypeGet(mlirQuantizedTypeGetSignedFlag(), i8, f32, -8, 7);
+
+  // CHECK: flags: 1
+  fprintf(stderr, "flags: %u\n", mlirQuantizedTypeGetFlags(any));
+  // CHECK: signed: 1
+  fprintf(stderr, "signed: %d\n", mlirQuantizedTypeIsSigned(any));
+  // CHECK: storage type: i8
+  fprintf(stderr, "storage type: ");
+  mlirTypeDump(mlirQuantizedTypeGetStorageType(any));
+  fprintf(stderr, "\n");
+  // CHECK: expressed type: f32
+  fprintf(stderr, "expressed type: ");
+  mlirTypeDump(mlirQuantizedTypeGetExpressedType(any));
+  fprintf(stderr, "\n");
+  // CHECK: storage min: -8
+  fprintf(stderr, "storage min: %" PRId64 "\n",
+          mlirQuantizedTypeGetStorageTypeMin(any));
+  // CHECK: storage max: 7
+  fprintf(stderr, "storage max: %" PRId64 "\n",
+          mlirQuantizedTypeGetStorageTypeMax(any));
+  // CHECK: storage width: 8
+  fprintf(stderr, "storage width: %u\n",
+          mlirQuantizedTypeGetStorageTypeIntegralWidth(any));
+  // CHECK: quantized element type: !quant.any<i8<-8:7>:f32>
+  fprintf(stderr, "quantized element type: ");
+  mlirTypeDump(mlirQuantizedTypeGetQuantizedElementType(any));
+  fprintf(stderr, "\n");
+
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(anyParsed, any));
+  // CHECK: !quant.any<i8<-8:7>:f32>
+  mlirTypeDump(any);
+  fprintf(stderr, "\n\n");
+}
+
+// CHECK-LABEL: testUniformType
+static void testUniformType(MlirContext ctx) {
+  fprintf(stderr, "testUniformType\n");
+
+  MlirType uniformParsed =
+      mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(
+                                "!quant.uniform<i8<-8:7>:f32, 0.99872:127>"));
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirType uniform = mlirUniformQuantizedTypeGet(
+      mlirQuantizedTypeGetSignedFlag(), i8, f32, 0.99872, 127, -8, 7);
+
+  // CHECK: scale: 0.998720
+  fprintf(stderr, "scale: %lf\n", mlirUniformQuantizedTypeGetScale(uniform));
+  // CHECK: zero point: 127
+  fprintf(stderr, "zero point: %" PRId64 "\n",
+          mlirUniformQuantizedTypeGetZeroPoint(uniform));
+  // CHECK: fixed point: 0
+  fprintf(stderr, "fixed point: %d\n",
+          mlirUniformQuantizedTypeIsFixedPoint(uniform));
+
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(uniform, uniformParsed));
+  // CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127>
+  mlirTypeDump(uniform);
+  fprintf(stderr, "\n\n");
+}
+
+// CHECK-LABEL: testUniformPerAxisType
+static void testUniformPerAxisType(MlirContext ctx) {
+  fprintf(stderr, "testUniformPerAxisType\n");
+
+  MlirType perAxisParsed = mlirTypeParseGet(
+      ctx, mlirStringRefCreateFromCString(
+               "!quant.uniform<i8:f32:1, {2.0e+2,0.99872:120}>"));
+
+  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
+  MlirType f32 = mlirF32TypeGet(ctx);
+  double scales[] = {200.0, 0.99872};
+  int64_t zeroPoints[] = {0, 120};
+  MlirType perAxis = mlirUniformQuantizedPerAxisTypeGet(
+      mlirQuantizedTypeGetSignedFlag(), i8, f32,
+      /*nDims=*/2, scales, zeroPoints,
+      /*quantizedDimension=*/1,
+      mlirQuantizedTypeGetDefaultMinimumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8),
+      mlirQuantizedTypeGetDefaultMaximumForInteger(/*isSigned=*/true,
+                                                   /*integralWidth=*/8));
+
+  // CHECK: num dims: 2
+  fprintf(stderr, "num dims: %" PRIdPTR "\n",
+          mlirUniformQuantizedPerAxisTypeGetNumDims(perAxis));
+  // CHECK: scale 0: 200.000000
+  fprintf(stderr, "scale 0: %lf\n",
+          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 0));
+  // CHECK: scale 1: 0.998720
+  fprintf(stderr, "scale 1: %lf\n",
+          mlirUniformQuantizedPerAxisTypeGetScale(perAxis, 1));
+  // CHECK: zero point 0: 0
+  fprintf(stderr, "zero point 0: %" PRId64 "\n",
+          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 0));
+  // CHECK: zero point 1: 120
+  fprintf(stderr, "zero point 1: %" PRId64 "\n",
+          mlirUniformQuantizedPerAxisTypeGetZeroPoint(perAxis, 1));
+  // CHECK: quantized dim: 1
+  fprintf(stderr, "quantized dim: %" PRId32 "\n",
+          mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(perAxis));
+  // CHECK: fixed point: 0
+  fprintf(stderr, "fixed point: %d\n",
+          mlirUniformQuantizedPerAxisTypeIsFixedPoint(perAxis));
+
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(perAxis, perAxisParsed));
+  // CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>
+  mlirTypeDump(perAxis);
+  fprintf(stderr, "\n\n");
+}
+
+// CHECK-LABEL: testCalibratedType
+static void testCalibratedType(MlirContext ctx) {
+  fprintf(stderr, "testCalibratedType\n");
+
+  MlirType calibratedParsed = mlirTypeParseGet(
+      ctx,
+      mlirStringRefCreateFromCString("!quant.calibrated<f32<-0.998:1.2321>>"));
+
+  MlirType f32 = mlirF32TypeGet(ctx);
+  MlirType calibrated = mlirCalibratedQuantizedTypeGet(f32, -0.998, 1.2321);
+
+  // CHECK: min: -0.998000
+  fprintf(stderr, "min: %lf\n", mlirCalibratedQuantizedTypeGetMin(calibrated));
+  // CHECK: max: 1.232100
+  fprintf(stderr, "max: %lf\n", mlirCalibratedQuantizedTypeGetMax(calibrated));
+
+  // CHECK: equal: 1
+  fprintf(stderr, "equal: %d\n", mlirTypeEqual(calibrated, calibratedParsed));
+  // CHECK: !quant.calibrated<f32<{{.*}}>>
+  mlirTypeDump(calibrated);
+  fprintf(stderr, "\n\n");
+}
+
+int main(void) {
+  MlirContext ctx = mlirContextCreate();
+  mlirDialectHandleRegisterDialect(mlirGetDialectHandle__quant__(), ctx);
+  testTypeHierarchy(ctx);
+  testAnyQuantizedType(ctx);
+  testUniformType(ctx);
+  testUniformPerAxisType(ctx);
+  testCalibratedType(ctx);
+  mlirContextDestroy(ctx);
+  return EXIT_SUCCESS;
+}
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -74,6 +74,7 @@
   mlir-capi-llvm-test
   mlir-capi-pass-test
   mlir-capi-sparse-tensor-test
+  mlir-capi-quant-test
   mlir-cpu-runner
   mlir-linalg-ods-yaml-gen
   mlir-lsp-server
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -65,6 +65,7 @@
     'mlir-capi-llvm-test',
     'mlir-capi-pass-test',
     'mlir-capi-sparse-tensor-test',
+    'mlir-capi-quant-test',
     'mlir-cpu-runner',
     'mlir-linalg-ods-yaml-gen',
     'mlir-reduce',
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -496,6 +496,24 @@
     ],
 )
 
+mlir_c_api_cc_library(
+    name = "CAPIQuant",
+    srcs = [
+        "lib/CAPI/Dialect/Quant.cpp",
+    ],
+    hdrs = [
+        "include/mlir-c/Dialect/Quant.h",
+    ],
+    header_deps = [
+        ":CAPIIRHeaders",
+    ],
+    includes = ["include"],
+    deps = [
+        ":CAPIIR",
+        ":QuantOps",
+    ],
+)
+
 mlir_c_api_cc_library(
     name = "CAPIConversion",
     srcs = ["lib/CAPI/Conversion/Passes.cpp"],