diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -309,6 +309,8 @@ return *this; } + py::object get_class() const { return thisClass; } + protected: py::object superClass; py::object thisClass; diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -0,0 +1,307 @@ +//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Dialects.h" +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python::adaptors; + +void mlir::python::populateDialectQuantSubmodule(const py::module &m, + const py::module &irModule) { + auto typeClass = irModule.attr("Type"); + + //===-------------------------------------------------------------------===// + // QuantizedType + //===-------------------------------------------------------------------===// + + auto quantizedType = mlir_type_subclass(m, "QuantizedType", + mlirTypeIsAQuantizedType, typeClass); + quantizedType.def_staticmethod( + "default_minimum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, + integralWidth); + }, + "Default minimum value for the integer with the specified signedness and " + "bit width.", + py::arg("is_signed"), py::arg("integral_width")); + 
quantizedType.def_staticmethod( + "default_maximum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, + integralWidth); + }, + "Default maximum value for the integer with the specified signedness and " + "bit width.", + py::arg("is_signed"), py::arg("integral_width")); + quantizedType.def_property_readonly( + "expressed_type", + [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, + "Type expressed by this quantized type."); + quantizedType.def_property_readonly( + "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, + "Flags of this quantized type (named accessors should be preferred to " + "this)"); + quantizedType.def_property_readonly( + "is_signed", + [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, + "Signedness of this quantized type."); + quantizedType.def_property_readonly( + "storage_type", + [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, + "Storage type backing this quantized type."); + quantizedType.def_property_readonly( + "storage_type_min", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, + "The minimum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_max", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, + "The maximum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_integral_width", + [](MlirType type) { + return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); + }, + "The bitwidth of the storage type of this quantized type."); + quantizedType.def( + "is_compatible_expressed_type", + [](MlirType type, MlirType candidate) { + return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); + }, + "Checks whether the candidate type can be expressed by this quantized " + "type.", + py::arg("candidate")); + 
quantizedType.def_property_readonly( + "quantized_element_type", + [](MlirType type) { + return mlirQuantizedTypeGetQuantizedElementType(type); + }, + "Element type of this quantized type expressed as quantized type."); + quantizedType.def( + "cast_from_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the storage type of this quantized type to a " + "corresponding type based on the quantized type. Raises TypeError if the " + "cast is not valid.", + py::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_storage_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToStorageType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the storage type of this quantized type. Raises TypeError if " + "the cast is not valid.", + py::arg("type")); + quantizedType.def( + "cast_from_expressed_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromExpressedType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the quantized type. 
Raises TypeError if " + "the cast is not valid.", + py::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_expressed_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the expressed type of this quantized type. Raises TypeError if " + "the cast is not valid.", + py::arg("type")); + quantizedType.def( + "cast_expressed_to_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastExpressedToStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the storage type. Raises TypeError if the " + "cast is not valid.", + py::arg("candidate")); + + quantizedType.get_class().attr("FLAG_SIGNED") = + mlirQuantizedTypeGetSignedFlag(); + + //===-------------------------------------------------------------------===// + // AnyQuantizedType + //===-------------------------------------------------------------------===// + + auto anyQuantizedType = + mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, + quantizedType.get_class()); + anyQuantizedType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of AnyQuantizedType in the same context as the " + "provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("storage_type_min"), + py::arg("storage_type_max")); + + 
//===-------------------------------------------------------------------===// + // UniformQuantizedType + //===-------------------------------------------------------------------===// + + auto uniformQuantizedType = mlir_type_subclass( + m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, + quantizedType.get_class()); + uniformQuantizedType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return cls(mlirUniformQuantizedTypeGet(flags, storageType, + expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of UniformQuantizedType in the same context as the " + "provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), + py::arg("storage_type_min"), py::arg("storage_type_max")); + uniformQuantizedType.def_property_readonly( + "scale", + [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, + "The scale designates the difference between the real values " + "corresponding to consecutive quantized values differing by 1."); + uniformQuantizedType.def_property_readonly( + "zero_point", + [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, + "The storage value corresponding to the real value 0 in the affine " + "equation."); + uniformQuantizedType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // UniformQuantizedPerAxisType + //===-------------------------------------------------------------------===// + auto uniformQuantizedPerAxisType = mlir_type_subclass( + m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, + 
quantizedType.get_class()); + uniformQuantizedPerAxisType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, std::vector<double> scales, + std::vector<int64_t> zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (scales.size() != zeroPoints.size()) + throw py::value_error( + "Mismatching number of scales and zero points."); + auto nDims = static_cast<intptr_t>(scales.size()); + return cls(mlirUniformQuantizedPerAxisTypeGet( + flags, storageType, expressedType, nDims, scales.data(), + zeroPoints.data(), quantizedDimension, storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedPerAxisType in the same context as " + "the provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), + py::arg("quantized_dimension"), py::arg("storage_type_min"), + py::arg("storage_type_max")); + uniformQuantizedPerAxisType.def_property_readonly( + "scales", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector<double> scales; + scales.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); + scales.push_back(scale); + } + return scales; + }, + "The scales designate the difference between the real values " + "corresponding to consecutive quantized values differing by 1. The ith " + "scale corresponds to the ith slice in the quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "zero_points", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector<int64_t> zeroPoints; + zeroPoints.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + int64_t zeroPoint = + mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); + zeroPoints.push_back(zeroPoint); + } + return zeroPoints; + }, + "the storage values corresponding to the real value 0 in the affine " + "equation. 
The ith zero point corresponds to the ith slice in the " + "quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "quantized_dimension", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); + }, + "Specifies the dimension of the shape that the scales and zero points " + "correspond to."); + uniformQuantizedPerAxisType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); + }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // CalibratedQuantizedType + //===-------------------------------------------------------------------===// + + auto calibratedQuantizedType = mlir_type_subclass( + m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, + quantizedType.get_class()); + calibratedQuantizedType.def_classmethod( + "get", + [](py::object cls, MlirType expressedType, double min, double max) { + return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); + }, + "Gets an instance of CalibratedQuantizedType in the same context as the " + "provided expressed type.", + py::arg("cls"), py::arg("expressed_type"), py::arg("min"), + py::arg("max")); + calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMin(type); + }); + calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMax(type); + }); +} diff --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h --- a/mlir/lib/Bindings/Python/Dialects.h +++ b/mlir/lib/Bindings/Python/Dialects.h @@ -17,6 +17,8 @@ void populateDialectLinalgSubmodule(pybind11::module m); void populateDialectSparseTensorSubmodule(const pybind11::module &m, const pybind11::module &irModule); +void populateDialectQuantSubmodule(const pybind11::module &m, + const 
pybind11::module &irModule); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -107,4 +107,6 @@ populateDialectLinalgSubmodule(linalgModule); populateDialectSparseTensorSubmodule( dialectsModule.def_submodule("sparse_tensor"), irModule); + populateDialectQuantSubmodule(dialectsModule.def_submodule("quant"), + irModule); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -25,6 +25,8 @@ _mlir_libs/_mlir/__init__.pyi _mlir_libs/_mlir/ir.pyi _mlir_libs/_mlir/passmanager.pyi + # TODO: this should be split out into a separate library. + _mlir_libs/_mlir/dialects/quant.pyi ) declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine @@ -115,6 +117,13 @@ dialects/_memref_ops_ext.py DIALECT_NAME memref) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.quant + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/quant.py) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -184,6 +193,7 @@ SOURCES DialectLinalg.cpp # TODO: Break this out. DialectSparseTensor.cpp # TODO: Break this out. + DialectQuant.cpp # TODO: Break this out. MainModule.cpp IRAffine.cpp IRAttributes.cpp @@ -212,6 +222,7 @@ MLIRCAPILinalg # TODO: Remove when above is removed. MLIRCAPISparseTensor # TODO: Remove when above is removed. MLIRCAPIStandard + MLIRCAPIQuant # TODO: Remove when above is removed. 
) declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -0,0 +1,123 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List + +from mlir.ir import Type + +__all__ = [ + "QuantizedType", + "AnyQuantizedType", + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "CalibratedQuantizedType", +] + +class QuantizedType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @staticmethod + def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @property + def expressed_type(self) -> Type: ... + + @property + def flags(self) -> int: ... + + @property + def is_signed(self) -> bool: ... + + @property + def storage_type(self) -> Type: ... + + @property + def storage_type_min(self) -> int: ... + + @property + def storage_type_max(self) -> int: ... + + @property + def storage_type_integral_width(self) -> int: ... + + def is_compatible_expressed_type(self, candidate: Type) -> bool: ... + + @property + def quantized_element_type(self) -> Type: ... + + def cast_from_storage_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_storage_type(type: Type) -> Type: ... + + def cast_from_expressed_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_expressed_type(type: Type) -> Type: ... + + def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ... 
+ + +class AnyQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + storage_type_min: int, storage_type_max: int) -> Type: + ... + + +class UniformQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scale: float, zero_point: int, storage_type_min: int, + storage_type_max: int) -> Type: ... + + @property + def scale(self) -> float: ... + + @property + def zero_point(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class UniformQuantizedPerAxisType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: List[float], zero_points: List[int], quantized_dimension: int, + storage_type_min: int, storage_type_max: int) -> Type: + ... + + @property + def scales(self) -> List[float]: ... + + @property + def zero_points(self) -> List[int]: ... + + @property + def quantized_dimension(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class CalibratedQuantizedType(QuantizedType): + + @classmethod + def get(cls, expressed_type: Type, min: float, max: float) -> Type: ... + + @property + def min(self) -> float: ... + + @property + def max(self) -> float: ... \ No newline at end of file diff --git a/mlir/python/mlir/dialects/quant.py b/mlir/python/mlir/dialects/quant.py new file mode 100644 --- /dev/null +++ b/mlir/python/mlir/dialects/quant.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._mlir_libs._mlir.dialects.quant import * diff --git a/mlir/test/python/dialects/quant.py b/mlir/test/python/dialects/quant.py new file mode 100644 --- /dev/null +++ b/mlir/test/python/dialects/quant.py @@ -0,0 +1,131 @@ +# RUN: %PYTHON %s | FileCheck %s + +from mlir.ir import * +from mlir.dialects import quant + + +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + +# CHECK-LABEL: TEST: test_type_hierarchy +@run +def test_type_hierarchy(): + with Context(): + i8 = IntegerType.get_signless(8) + any = Type.parse("!quant.any<i8<-8:7>:f32>") + uniform = Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>") + per_axis = Type.parse("!quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>") + calibrated = Type.parse("!quant.calibrated<f32<-0.998:1.2321>>") + + assert not quant.QuantizedType.isinstance(i8) + assert quant.QuantizedType.isinstance(any) + assert quant.QuantizedType.isinstance(uniform) + assert quant.QuantizedType.isinstance(per_axis) + assert quant.QuantizedType.isinstance(calibrated) + + assert quant.AnyQuantizedType.isinstance(any) + assert quant.UniformQuantizedType.isinstance(uniform) + assert quant.UniformQuantizedPerAxisType.isinstance(per_axis) + assert quant.CalibratedQuantizedType.isinstance(calibrated) + + assert not quant.AnyQuantizedType.isinstance(uniform) + assert not quant.UniformQuantizedType.isinstance(per_axis) + + +# CHECK-LABEL: TEST: test_any_quantized_type +@run +def test_any_quantized_type(): + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + any = quant.AnyQuantizedType.get(quant.QuantizedType.FLAG_SIGNED, i8, f32, + -8, 7) + + # CHECK: flags: 1 + print(f"flags: {any.flags}") + # CHECK: signed: True + print(f"signed: {any.is_signed}") + # CHECK: storage type: i8 + print(f"storage type: {any.storage_type}") + # CHECK: expressed type: f32 + print(f"expressed type: {any.expressed_type}") + # CHECK: storage min: -8 + print(f"storage min: {any.storage_type_min}") + # CHECK: storage max: 7 + print(f"storage 
max: {any.storage_type_max}") + # CHECK: storage width: 8 + print(f"storage width: {any.storage_type_integral_width}") + # CHECK: quantized element type: !quant.any<i8<-8:7>:f32> + print(f"quantized element type: {any.quantized_element_type}") + # CHECK: !quant.any<i8<-8:7>:f32> + print(any) + assert any == Type.parse("!quant.any<i8<-8:7>:f32>") + + +# CHECK-LABEL: TEST: test_uniform_type +@run +def test_uniform_type(): + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + uniform = quant.UniformQuantizedType.get( + quant.UniformQuantizedType.FLAG_SIGNED, i8, f32, 0.99872, 127, -8, 7) + + # CHECK: scale: 0.99872 + print(f"scale: {uniform.scale}") + # CHECK: zero point: 127 + print(f"zero point: {uniform.zero_point}") + # CHECK: fixed point: False + print(f"fixed point: {uniform.is_fixed_point}") + # CHECK: !quant.uniform<i8<-8:7>:f32, 9.987200e-01:127> + print(uniform) + assert uniform == Type.parse("!quant.uniform<i8<-8:7>:f32, 0.99872:127>") + + +# CHECK-LABEL: TEST: test_uniform_per_axis_type +@run +def test_uniform_per_axis_type(): + with Context(): + i8 = IntegerType.get_signless(8) + f32 = F32Type.get() + per_axis = quant.UniformQuantizedPerAxisType.get( + quant.QuantizedType.FLAG_SIGNED, + i8, + f32, [200, 0.99872], [0, 120], + quantized_dimension=1, + storage_type_min=quant.QuantizedType.default_minimum_for_integer( + is_signed=True, integral_width=8), + storage_type_max=quant.QuantizedType.default_maximum_for_integer( + is_signed=True, integral_width=8)) + + # CHECK: scales: [200.0, 0.99872] + print(f"scales: {per_axis.scales}") + # CHECK: zero_points: [0, 120] + print(f"zero_points: {per_axis.zero_points}") + # CHECK: quantized dim: 1 + print(f"quantized dim: {per_axis.quantized_dimension}") + # CHECK: fixed point: False + print(f"fixed point: {per_axis.is_fixed_point}") + # CHECK: !quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}> + print(per_axis) + assert per_axis == Type.parse( + "!quant.uniform<i8:f32:1, {2.000000e+02,9.987200e-01:120}>") + + +# CHECK-LABEL: TEST: test_calibrated_type +@run +def test_calibrated_type(): + with Context(): + f32 = F32Type.get() + 
calibrated = quant.CalibratedQuantizedType.get(f32, -0.998, 1.2321) + + # CHECK: min: -0.998 + print(f"min: {calibrated.min}") + # CHECK: max: 1.2321 + print(f"max: {calibrated.max}") + # CHECK: !quant.calibrated<f32<-0.998:1.2321>> + print(calibrated) + assert calibrated == Type.parse("!quant.calibrated<f32<-0.998:1.2321>>") diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -650,6 +650,7 @@ MLIR_PYTHON_BINDINGS_SOURCES = [ "lib/Bindings/Python/DialectLinalg.cpp", "lib/Bindings/Python/DialectSparseTensor.cpp", + "lib/Bindings/Python/DialectQuant.cpp", "lib/Bindings/Python/IRAffine.cpp", "lib/Bindings/Python/IRAttributes.cpp", "lib/Bindings/Python/IRCore.cpp", @@ -683,6 +684,7 @@ ":CAPIIR", ":CAPIInterfaces", ":CAPILinalg", + ":CAPIQuant", ":CAPIRegistration", ":CAPISparseTensor", ":MLIRBindingsPythonHeadersAndDeps", @@ -714,6 +716,7 @@ ":CAPIGPUHeaders", ":CAPIIRHeaders", ":CAPILinalgHeaders", + ":CAPIQuantHeaders", ":CAPIRegistrationHeaders", ":CAPISparseTensorHeaders", ":MLIRBindingsPythonHeaders", @@ -738,6 +741,7 @@ ":CAPIIRObjects", ":CAPIInterfacesObjects", ":CAPILinalgObjects", + ":CAPIQuantObjects", ":CAPIRegistrationObjects", ":CAPISparseTensorObjects", ], diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -317,6 +317,17 @@ ], ) +##---------------------------------------------------------------------------## +# Quant dialect. +##---------------------------------------------------------------------------## + +filegroup( + name = "QuantPyFiles", + srcs = [ + "mlir/dialects/quant.py", + ], +) + ##---------------------------------------------------------------------------## # PythonTest dialect. 
##---------------------------------------------------------------------------## @@ -607,4 +618,4 @@ "mlir/dialects/vector.py", ":VectorOpsPyGen", ], -) \ No newline at end of file +)