diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -15,6 +15,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/Types.h" +#include "mlir/Interfaces/SizedTypeInterface.h" #include "llvm/Support/MathExtras.h" namespace mlir { @@ -171,6 +172,10 @@ /// tensor<4xf32> -> tensor<4xi8> Type castExpressedToStorageType(Type candidateType); + /// Returns the size in bytes of the storage type. Subclasses that attach + /// SizedTypeInterface::Trait use this as their getSizeInBytes() method. + int64_t getSizeInBytes(); + private: /// Hide the following methods inherited from `Type`. It is almost certainly /// a bug to call them from a `QuantizedType` object. Users should call @@ -195,7 +200,8 @@ /// Note that for the any type, the expressed type is optional. class AnyQuantizedType : public Type::TypeBase { + detail::AnyQuantizedTypeStorage, + SizedTypeInterface::Trait> { public: using Base::Base; @@ -252,7 +258,8 @@ /// ZeroPoint: An integer value class UniformQuantizedType : public Type::TypeBase { + detail::UniformQuantizedTypeStorage, + SizedTypeInterface::Trait> { public: using Base::Base; @@ -309,7 +316,8 @@ /// ZeroPoint: An integer value class UniformQuantizedPerAxisType : public Type::TypeBase { + detail::UniformQuantizedPerAxisTypeStorage, + SizedTypeInterface::Trait> { public: using Base::Base; diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -456,9 +456,7 @@ using ShapedType::ShapedType; /// Return true if the specified element type is ok in a memref.
- static bool isValidElementType(Type type) { - return type.isIntOrIndexOrFloat() || type.isa(); - } + static bool isValidElementType(Type type); /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Type type); diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) +add_mlir_interface(SizedTypeInterface) add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) diff --git a/mlir/include/mlir/Interfaces/SizedTypeInterface.h b/mlir/include/mlir/Interfaces/SizedTypeInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/SizedTypeInterface.h @@ -0,0 +1,21 @@ +//===- SizedTypeInterface.h - SizedType interface ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the interface for types that report their size in bytes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_SIZEDTYPEINTERFACE_H_ +#define MLIR_INTERFACES_SIZEDTYPEINTERFACE_H_ + +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations.
+#include "mlir/Interfaces/SizedTypeInterface.h.inc" + +#endif // MLIR_INTERFACES_SIZEDTYPEINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/SizedTypeInterface.td b/mlir/include/mlir/Interfaces/SizedTypeInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/SizedTypeInterface.td @@ -0,0 +1,31 @@ +//===- SizedTypeInterface.td - SizedType interface ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines the interface for types that report their size in bytes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_SIZEDTYPEINTERFACE +#define MLIR_INTERFACES_SIZEDTYPEINTERFACE + +include "mlir/IR/OpBase.td" + +def SizedTypeInterface : TypeInterface<"SizedTypeInterface"> { + let description = [{ + An interface for types that can report their size in memory, in bytes.
+ }]; + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + "Returns the size of this type in memory, in bytes.", + "int64_t", "getSizeInBytes"> + ]; +} + +#endif // MLIR_INTERFACES_SIZEDTYPEINTERFACE diff --git a/mlir/lib/Dialect/Quant/CMakeLists.txt b/mlir/lib/Dialect/Quant/CMakeLists.txt --- a/mlir/lib/Dialect/Quant/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/CMakeLists.txt @@ -20,6 +20,7 @@ MLIRIR MLIRPass MLIRSideEffectInterfaces + MLIRSizedTypeInterface MLIRSupport MLIRStandard MLIRTransformUtils diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -200,6 +200,11 @@ return QuantizedType::castToStorageType(expressedQuantizedType); } +int64_t QuantizedType::getSizeInBytes() { + // Round up so sub-byte storage types (e.g. i4) report 1 byte rather than 0. + return llvm::divideCeil(getStorageType().getIntOrFloatBitWidth(), CHAR_BIT); +} + AnyQuantizedType AnyQuantizedType::get(unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -34,6 +34,7 @@ DEPENDS MLIRCallInterfacesIncGen MLIROpAsmInterfaceIncGen + MLIRSizedTypeInterfaceIncGen MLIRSymbolInterfacesIncGen MLIRRegionKindInterfaceIncGen diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -12,6 +12,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/SizedTypeInterface.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/Twine.h" @@ -363,6 +364,15 @@ // BaseMemRefType //===----------------------------------------------------------------------===// +/// Return true if the specified element type is ok in a memref.
+bool BaseMemRefType::isValidElementType(Type type) { + // Besides the builtin element types, any type implementing + // SizedTypeInterface (such as the quant dialect types) is accepted, since it + // can report how many bytes each element occupies. + return type.isIntOrIndexOrFloat() || type.isa() || + type.isa<SizedTypeInterface>(); +} + unsigned BaseMemRefType::getMemorySpace() const { return static_cast(impl)->memorySpace; } diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp + SizedTypeInterface.cpp VectorInterfaces.cpp ViewLikeInterface.cpp ) @@ -33,6 +34,7 @@ add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) +add_mlir_interface_library(SizedTypeInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) diff --git a/mlir/lib/Interfaces/SizedTypeInterface.cpp b/mlir/lib/Interfaces/SizedTypeInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/SizedTypeInterface.cpp @@ -0,0 +1,18 @@ +//===- SizedTypeInterface.cpp - SizedType interface -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/SizedTypeInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// SizedType Interface +//===----------------------------------------------------------------------===// + +/// Include the definitions of the SizedType Interface.
+#include "mlir/Interfaces/SizedTypeInterface.cpp.inc" diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -217,8 +217,7 @@ return nullptr; // Check that memref is formed from allowed types. - if (!elementType.isIntOrIndexOrFloat() && - !elementType.isa()) + if (!BaseMemRefType::isValidElementType(elementType)) return emitError(typeLoc, "invalid memref element type"), nullptr; // Parse semi-affine-map-composition. diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -78,6 +78,9 @@ // CHECK: func @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>) func @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>) +// CHECK: func @memref_quant_element(memref<3x!quant.uniform>) +func @memref_quant_element(memref<3x!quant.uniform>) + // Test memref affine map compositions. // CHECK: func @memrefs2(memref<2x4x8xi8, 1>)