diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1525,11 +1525,10 @@ - If the `genAccessors` field is 1 (the default) accessor methods will be generated on the Type class (e.g. `int getWidth() const` in the example above). -- If the `genVerifyInvariantsDecl` field is set, a declaration for a method - `static LogicalResult verifyConstructionInvariants(Location, parameters...)` - is added to the class as well as a `getChecked(Location, parameters...)` - method which gets the result of `verifyConstructionInvariants` before - calling `get`. +- If the `genVerifyDecl` field is set, a declaration for a method `static + LogicalResult verify(emitErrorFn, parameters...)` is added to the class as + well as a `getChecked(emitErrorFn, parameters...)` method which checks the + result of `verify` before calling `get`. - The `storageClass` field can be used to set the name of the storage class. - The `storageNamespace` field is used to set the namespace where the storage class should sit. Defaults to "detail". @@ -1555,9 +1554,9 @@ // given set of parameters. static MyType get(MLIRContext *context, int intParam); -// If `genVerifyInvariantsDecl` is set to 1, the following method is also -// generated. -static MyType getChecked(Location loc, int intParam); +// If `genVerifyDecl` is set to 1, the following method is also generated. 
+static MyType getChecked(function_ref emitError, + MLIRContext *context, int intParam); ``` If these autogenerated methods are not desired, such as when they conflict with diff --git a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md --- a/mlir/docs/Tutorials/DefiningAttributesAndTypes.md +++ b/mlir/docs/Tutorials/DefiningAttributesAndTypes.md @@ -161,27 +161,28 @@ return Base::get(type.getContext(), param, type); } - /// This method is used to get an instance of the 'ComplexType', defined at - /// the given location. If any of the construction invariants are invalid, - /// errors are emitted with the provided location and a null type is returned. + /// This method is used to get an instance of the 'ComplexType'. If any of the + /// construction invariants are invalid, errors are emitted with the provided + /// `emitError` function and a null type is returned. /// Note: This method is completely optional. - static ComplexType getChecked(unsigned param, Type type, Location location) { + static ComplexType getChecked(function_ref emitError, + unsigned param, Type type) { // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued // instance of this type. All parameters to the storage class are passed - // after the location. - return Base::getChecked(location, param, type); + // after the context. + return Base::getChecked(emitError, type.getContext(), param, type); } /// This method is used to verify the construction invariants passed into the /// 'get' and 'getChecked' methods. Note: This method is completely optional. - static LogicalResult verifyConstructionInvariants( - Location loc, unsigned param, Type type) { + static LogicalResult verify(function_ref emitError, + unsigned param, Type type) { // Our type only allows non-zero parameters. 
if (param == 0) - return emitError(loc) << "non-zero parameter passed to 'ComplexType'"; + return emitError() << "non-zero parameter passed to 'ComplexType'"; // Our type also expects an integer type. if (!type.isa()) - return emitError(loc) << "non integer-type passed to 'ComplexType'"; + return emitError() << "non integer-type passed to 'ComplexType'"; return success(); } diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -98,8 +98,9 @@ /// Same as "mlirFloatAttrDoubleGet", but if the type is not valid for a /// construction of a FloatAttr, returns a null MlirAttribute. -MLIR_CAPI_EXPORTED MlirAttribute -mlirFloatAttrDoubleGetChecked(MlirType type, double value, MlirLocation loc); +MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, + MlirType type, + double value); /// Returns the value stored in the given floating point attribute, interpreting /// the value as double. diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -170,10 +170,10 @@ /// Same as "mlirVectorTypeGet" but returns a nullptr wrapping MlirType on /// illegal arguments, emitting appropriate diagnostics. -MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(intptr_t rank, +MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, + intptr_t rank, const int64_t *shape, - MlirType elementType, - MlirLocation loc); + MlirType elementType); //===----------------------------------------------------------------------===// // Ranked / Unranked Tensor type. @@ -196,10 +196,9 @@ /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on /// illegal arguments, emitting appropriate diagnostics. 
-MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, - const int64_t *shape, - MlirType elementType, - MlirLocation loc); +MLIR_CAPI_EXPORTED MlirType +mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, MlirType elementType); /// Creates an unranked tensor type with the given element type in the same /// context as the element type. The type is owned by the context. @@ -208,7 +207,7 @@ /// Same as "mlirUnrankedTensorTypeGet" but returns a nullptr wrapping MlirType /// on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType -mlirUnrankedTensorTypeGetChecked(MlirType elementType, MlirLocation loc); +mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType); //===----------------------------------------------------------------------===// // Ranked / Unranked MemRef type. @@ -230,8 +229,8 @@ /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( - MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace, MlirLocation loc); + MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, + intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, @@ -245,8 +244,8 @@ /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. 
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( - MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace, MlirLocation loc); + MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, + unsigned memorySpace); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. The type is owned by the context of element type. @@ -256,7 +255,7 @@ /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirType elementType, unsigned memorySpace, MlirLocation loc); + MlirLocation loc, MlirType elementType, unsigned memorySpace); /// Returns the number of affine layout maps in the given MemRef type. MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncTypes.td @@ -43,9 +43,7 @@ let parameters = (ins "Type":$valueType); let builders = [ TypeBuilderWithInferredContext<(ins "Type":$valueType), [{ - return Base::get(valueType.getContext(), valueType); - }], [{ - return Base::getChecked($_loc, valueType); + return $_get(valueType.getContext(), valueType); }]> ]; let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -68,6 +68,7 @@ public: /// Inherit base constructors. using Base::Base; + using Base::getChecked; /// Checks if the given type can be used inside an array type. 
static bool isValidElementType(Type type); @@ -75,8 +76,8 @@ /// Gets or creates an instance of LLVM dialect array type containing /// `numElements` of `elementType`, in the same context as `elementType`. static LLVMArrayType get(Type elementType, unsigned numElements); - static LLVMArrayType getChecked(Location loc, Type elementType, - unsigned numElements); + static LLVMArrayType getChecked(function_ref emitError, + Type elementType, unsigned numElements); /// Returns the element type of the array. Type getElementType(); @@ -85,9 +86,8 @@ unsigned getNumElements(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType, - unsigned numElements); + static LogicalResult verify(function_ref emitError, + Type elementType, unsigned numElements); }; //===----------------------------------------------------------------------===// @@ -103,6 +103,7 @@ public: /// Inherit base constructors. using Base::Base; + using Base::getChecked; /// Checks if the given type can be used an argument in a function type. static bool isValidArgumentType(Type type); @@ -117,9 +118,9 @@ /// as the `result` type. static LLVMFunctionType get(Type result, ArrayRef arguments, bool isVarArg = false); - static LLVMFunctionType getChecked(Location loc, Type result, - ArrayRef arguments, - bool isVarArg = false); + static LLVMFunctionType + getChecked(function_ref emitError, Type result, + ArrayRef arguments, bool isVarArg = false); /// Returns the result type of the function. Type getReturnType(); @@ -135,9 +136,8 @@ ArrayRef params() { return getParams(); } /// Verifies that the type about to be constructed is well-formed. 
- static LogicalResult verifyConstructionInvariants(Location loc, Type result, - ArrayRef arguments, - bool); + static LogicalResult verify(function_ref emitError, + Type result, ArrayRef arguments, bool); }; //===----------------------------------------------------------------------===// @@ -152,6 +152,7 @@ public: /// Inherit base constructors. using Base::Base; + using Base::getChecked; /// Checks if the given type can have a pointer type pointing to it. static bool isValidElementType(Type type); @@ -160,8 +161,9 @@ /// object of `pointee` type in the given address space. The pointer type is /// created in the same context as `pointee`. static LLVMPointerType get(Type pointee, unsigned addressSpace = 0); - static LLVMPointerType getChecked(Location loc, Type pointee, - unsigned addressSpace = 0); + static LLVMPointerType + getChecked(function_ref emitError, Type pointee, + unsigned addressSpace = 0); /// Returns the pointed-to type. Type getElementType(); @@ -170,8 +172,8 @@ unsigned getAddressSpace(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location loc, Type pointee, - unsigned); + static LogicalResult verify(function_ref emitError, + Type pointee, unsigned); }; //===----------------------------------------------------------------------===// @@ -217,7 +219,9 @@ /// in the context. Instead, it will just return the existing struct, /// similarly to the rest of MLIR type ::get methods. static LLVMStructType getIdentified(MLIRContext *context, StringRef name); - static LLVMStructType getIdentifiedChecked(Location loc, StringRef name); + static LLVMStructType + getIdentifiedChecked(function_ref emitError, + MLIRContext *context, StringRef name); /// Gets a new identified struct with the given body. The body _cannot_ be /// changed later. If a struct with the given name already exists, renames @@ -231,8 +235,10 @@ /// context. 
static LLVMStructType getLiteral(MLIRContext *context, ArrayRef types, bool isPacked = false); - static LLVMStructType getLiteralChecked(Location loc, ArrayRef types, - bool isPacked = false); + static LLVMStructType + getLiteralChecked(function_ref emitError, + MLIRContext *context, ArrayRef types, + bool isPacked = false); /// Gets or creates an intentionally-opaque identified struct. Such a struct /// cannot have its body set. To create an opaque struct with a mutable body, @@ -241,7 +247,9 @@ /// already exists in the context. Instead, it will just return the existing /// struct, similarly to the rest of MLIR type ::get methods. static LLVMStructType getOpaque(StringRef name, MLIRContext *context); - static LLVMStructType getOpaqueChecked(Location loc, StringRef name); + static LLVMStructType + getOpaqueChecked(function_ref emitError, + MLIRContext *context, StringRef name); /// Set the body of an identified struct. Returns failure if the body could /// not be set, e.g. if the struct already has a body or if it was marked as @@ -270,9 +278,10 @@ ArrayRef getBody(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location, StringRef, bool); - static LogicalResult verifyConstructionInvariants(Location loc, - ArrayRef types, bool); + static LogicalResult verify(function_ref emitError, + StringRef, bool); + static LogicalResult verify(function_ref emitError, + ArrayRef types, bool); }; //===----------------------------------------------------------------------===// @@ -300,9 +309,8 @@ llvm::ElementCount getElementCount(); /// Verifies that the type about to be constructed is well-formed. 
- static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType, - unsigned numElements); + static LogicalResult verify(function_ref emitError, + Type elementType, unsigned numElements); }; //===----------------------------------------------------------------------===// @@ -317,12 +325,14 @@ public: /// Inherit base constructor. using Base::Base; + using Base::getChecked; /// Gets or creates a fixed vector type containing `numElements` of /// `elementType` in the same context as `elementType`. static LLVMFixedVectorType get(Type elementType, unsigned numElements); - static LLVMFixedVectorType getChecked(Location loc, Type elementType, - unsigned numElements); + static LLVMFixedVectorType + getChecked(function_ref emitError, Type elementType, + unsigned numElements); /// Checks if the given type can be used in a vector type. This type supports /// only a subset of LLVM dialect types that don't have a built-in @@ -336,9 +346,8 @@ unsigned getNumElements(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType, - unsigned numElements); + static LogicalResult verify(function_ref emitError, + Type elementType, unsigned numElements); }; //===----------------------------------------------------------------------===// @@ -354,12 +363,14 @@ public: /// Inherit base constructor. using Base::Base; + using Base::getChecked; /// Gets or creates a scalable vector type containing a non-zero multiple of /// `minNumElements` of `elementType` in the same context as `elementType`. static LLVMScalableVectorType get(Type elementType, unsigned minNumElements); - static LLVMScalableVectorType getChecked(Location loc, Type elementType, - unsigned minNumElements); + static LLVMScalableVectorType + getChecked(function_ref emitError, Type elementType, + unsigned minNumElements); /// Checks if the given type can be used in a vector type. 
static bool isValidElementType(Type type); @@ -373,9 +384,8 @@ unsigned getMinNumElements(); /// Verifies that the type about to be constructed is well-formed. - static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType, - unsigned minNumElements); + static LogicalResult verify(function_ref emitError, + Type elementType, unsigned minNumElements); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Quant/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/QuantTypes.h --- a/mlir/include/mlir/Dialect/Quant/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/QuantTypes.h @@ -57,10 +57,10 @@ /// The maximum number of bits supported for storage types. static constexpr unsigned MaxStorageBits = 32; - static LogicalResult - verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, - Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); /// Support method to enable LLVM-style type casting. static bool classof(Type type); @@ -199,6 +199,7 @@ detail::AnyQuantizedTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Gets an instance of the type with all parameters specified but not /// checked. @@ -208,15 +209,16 @@ /// Gets an instance of the type with all specified parameters checked. /// Returns a nullptr convertible type on failure. - static AnyQuantizedType getChecked(unsigned flags, Type storageType, - Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax, Location location); + static AnyQuantizedType + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. 
- static LogicalResult - verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, - Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); }; /// Represents a family of uniform, quantized types. @@ -256,6 +258,7 @@ detail::UniformQuantizedTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Gets an instance of the type with all parameters specified but not /// checked. @@ -267,16 +270,16 @@ /// Gets an instance of the type with all specified parameters checked. /// Returns a nullptr convertible type on failure. static UniformQuantizedType - getChecked(unsigned flags, Type storageType, Type expressedType, double scale, - int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, - Location location); + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult - verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, - Type expressedType, double scale, - int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax); + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); /// Gets the scale term. The scale designates the difference between the real /// values corresponding to consecutive quantized values differing by 1. @@ -313,6 +316,7 @@ detail::UniformQuantizedPerAxisTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Gets an instance of the type with all parameters specified but not /// checked. 
@@ -325,18 +329,18 @@ /// Gets an instance of the type with all specified parameters checked. /// Returns a nullptr convertible type on failure. static UniformQuantizedPerAxisType - getChecked(unsigned flags, Type storageType, Type expressedType, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, - int64_t storageTypeMax, Location location); + getChecked(function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult - verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, - Type expressedType, ArrayRef scales, - ArrayRef zeroPoints, - int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax); + static LogicalResult verify(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, + int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); /// Gets the quantization scales. The scales designate the difference between /// the real values corresponding to consecutive quantized values differing @@ -381,6 +385,7 @@ detail::CalibratedQuantizedTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Gets an instance of the type with all parameters specified but not /// checked. @@ -389,13 +394,13 @@ /// Gets an instance of the type with all specified parameters checked. /// Returns a nullptr convertible type on failure. - static CalibratedQuantizedType getChecked(Type expressedType, double min, - double max, Location location); + static CalibratedQuantizedType + getChecked(function_ref emitError, Type expressedType, + double min, double max); /// Verifies construction invariants and issues errors/warnings. 
- static LogicalResult verifyConstructionInvariants(Location loc, - Type expressedType, - double min, double max); + static LogicalResult verify(function_ref emitError, + Type expressedType, double min, double max); double getMin() const; double getMax() const; }; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.h @@ -69,10 +69,9 @@ /// Returns `spirv::StorageClass`. Optional getStorageClass(); - static LogicalResult verifyConstructionInvariants(Location loc, - IntegerAttr descriptorSet, - IntegerAttr binding, - IntegerAttr storageClass); + static LogicalResult verify(function_ref emitError, + IntegerAttr descriptorSet, IntegerAttr binding, + IntegerAttr storageClass); }; /// An attribute that specifies the SPIR-V (version, capabilities, extensions) @@ -120,10 +119,9 @@ /// Returns the capabilities as an integer array attribute. ArrayAttr getCapabilitiesAttr(); - static LogicalResult verifyConstructionInvariants(Location loc, - IntegerAttr version, - ArrayAttr capabilities, - ArrayAttr extensions); + static LogicalResult verify(function_ref emitError, + IntegerAttr version, ArrayAttr capabilities, + ArrayAttr extensions); }; /// An attribute that specifies the target version, allowed extensions and @@ -174,10 +172,10 @@ /// Returns the target resource limits. 
ResourceLimitsAttr getResourceLimits() const; - static LogicalResult - verifyConstructionInvariants(Location loc, VerCapExtAttr triple, - Vendor vendorID, DeviceType deviceType, - uint32_t deviceID, DictionaryAttr limits); + static LogicalResult verify(function_ref emitError, + VerCapExtAttr triple, Vendor vendorID, + DeviceType deviceType, uint32_t deviceID, + DictionaryAttr limits); }; } // namespace spirv } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -243,10 +243,11 @@ static SampledImageType get(Type imageType); - static SampledImageType getChecked(Type imageType, Location location); + static SampledImageType + getChecked(function_ref emitError, Type imageType); - static LogicalResult verifyConstructionInvariants(Location Loc, - Type imageType); + static LogicalResult verify(function_ref emitError, + Type imageType); Type getImageType() const; @@ -426,12 +427,11 @@ static MatrixType get(Type columnType, uint32_t columnCount); - static MatrixType getChecked(Type columnType, uint32_t columnCount, - Location location); + static MatrixType getChecked(function_ref emitError, + Type columnType, uint32_t columnCount); - static LogicalResult verifyConstructionInvariants(Location loc, - Type columnType, - uint32_t columnCount); + static LogicalResult verify(function_ref emitError, + Type columnType, uint32_t columnCount); /// Returns true if the matrix elements are vectors of float elements. 
static bool isValidColumnType(Type columnType); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -182,17 +182,20 @@ detail::FloatAttributeStorage> { public: using Base::Base; + using Base::getChecked; using ValueType = APFloat; /// Return a float attribute for the specified value in the specified type. /// These methods should only be used for simple constant values, e.g 1.0/2.0, /// that are known-valid both as host double and the 'type' format. static FloatAttr get(Type type, double value); - static FloatAttr getChecked(Type type, double value, Location loc); + static FloatAttr getChecked(function_ref emitError, + Type type, double value); /// Return a float attribute for the specified value in the specified type. static FloatAttr get(Type type, const APFloat &value); - static FloatAttr getChecked(Type type, const APFloat &value, Location loc); + static FloatAttr getChecked(function_ref emitError, + Type type, const APFloat &value); APFloat getValue() const; @@ -202,10 +205,10 @@ static double getValueAsDouble(APFloat val); /// Verify the construction invariants for a double value. - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - double value); - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - const APFloat &value); + static LogicalResult verify(function_ref emitError, + Type type, double value); + static LogicalResult verify(function_ref emitError, + Type type, const APFloat &value); }; //===----------------------------------------------------------------------===// @@ -234,10 +237,10 @@ /// an unsigned integer. 
uint64_t getUInt() const; - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - int64_t value); - static LogicalResult verifyConstructionInvariants(Location loc, Type type, - const APInt &value); + static LogicalResult verify(function_ref emitError, + Type type, int64_t value); + static LogicalResult verify(function_ref emitError, + Type type, const APInt &value); }; //===----------------------------------------------------------------------===// @@ -290,6 +293,7 @@ detail::OpaqueAttributeStorage> { public: using Base::Base; + using Base::getChecked; /// Get or create a new OpaqueAttr with the provided dialect and string data. static OpaqueAttr get(MLIRContext *context, Identifier dialect, @@ -298,8 +302,9 @@ /// Get or create a new OpaqueAttr with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a /// null attribute is returned. - static OpaqueAttr getChecked(Identifier dialect, StringRef attrData, - Type type, Location location); + static OpaqueAttr getChecked(function_ref emitError, + Identifier dialect, StringRef attrData, + Type type); /// Returns the dialect namespace of the opaque attribute. Identifier getDialectNamespace() const; @@ -308,10 +313,9 @@ StringRef getAttrData() const; /// Verify the construction of an opaque attribute. 
- static LogicalResult verifyConstructionInvariants(Location loc, - Identifier dialect, - StringRef attrData, - Type type); + static LogicalResult verify(function_ref emitError, + Identifier dialect, StringRef attrData, + Type type); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -167,22 +167,21 @@ : public Type::TypeBase { public: using Base::Base; + using Base::getChecked; /// Get or create a new VectorType of the provided shape and element type. /// Assumes the arguments define a well-formed VectorType. static VectorType get(ArrayRef shape, Type elementType); - /// Get or create a new VectorType of the provided shape and element type - /// declared at the given, potentially unknown, location. If the VectorType - /// defined by the arguments would be ill-formed, emit errors and return - /// nullptr-wrapping type. - static VectorType getChecked(Location location, ArrayRef shape, - Type elementType); + /// Get or create a new VectorType of the provided shape and element type. If + /// the VectorType defined by the arguments would be ill-formed, an error is + /// emitted to `emitError` and a null type is returned. + static VectorType getChecked(function_ref emitError, + ArrayRef shape, Type elementType); /// Verify the construction of a vector type. - static LogicalResult verifyConstructionInvariants(Location loc, - ArrayRef shape, - Type elementType); + static LogicalResult verify(function_ref emitError, + ArrayRef shape, Type elementType); /// Returns true of the given type can be used as an element of a vector type. /// In particular, vectors can consist of integer or float primitives. 
@@ -226,22 +225,23 @@ detail::RankedTensorTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Get or create a new RankedTensorType of the provided shape and element /// type. Assumes the arguments define a well-formed type. static RankedTensorType get(ArrayRef shape, Type elementType); /// Get or create a new RankedTensorType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// RankedTensorType defined by the arguments would be ill-formed, emit errors - /// and return a nullptr-wrapping type. - static RankedTensorType getChecked(Location location, ArrayRef shape, - Type elementType); + /// type. If the RankedTensorType defined by the arguments would be + /// ill-formed, an error is emitted to `emitError` and a null type is + /// returned. + static RankedTensorType + getChecked(function_ref emitError, + ArrayRef shape, Type elementType); /// Verify the construction of a ranked tensor type. - static LogicalResult verifyConstructionInvariants(Location loc, - ArrayRef shape, - Type elementType); + static LogicalResult verify(function_ref emitError, + ArrayRef shape, Type elementType); ArrayRef getShape() const; }; @@ -256,20 +256,22 @@ detail::UnrankedTensorTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Get or create a new UnrankedTensorType of the provided shape and element /// type. Assumes the arguments define a well-formed type. static UnrankedTensorType get(Type elementType); /// Get or create a new UnrankedTensorType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// UnrankedTensorType defined by the arguments would be ill-formed, emit - /// errors and return a nullptr-wrapping type. - static UnrankedTensorType getChecked(Location location, Type elementType); + /// type. 
If the UnrankedTensorType defined by the arguments would be + /// ill-formed, an error is emitted to `emitError` and a null type is + /// returned. + static UnrankedTensorType + getChecked(function_ref emitError, Type elementType); /// Verify the construction of a unranked tensor type. - static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType); + static LogicalResult verify(function_ref emitError, + Type elementType); ArrayRef getShape() const { return llvm::None; } }; @@ -351,6 +353,7 @@ }; using Base::Base; + using Base::getChecked; /// Get or create a new MemRefType based on shape, element type, affine /// map composition, and memory space. Assumes the arguments define a @@ -361,13 +364,11 @@ unsigned memorySpace = 0); /// Get or create a new MemRefType based on shape, element type, affine - /// map composition, and memory space declared at the given location. - /// If the location is unknown, the last argument should be an instance of - /// UnknownLoc. If the MemRefType defined by the arguments would be - /// ill-formed, emits errors (to the handler registered with the context or to - /// the error stream) and returns nullptr. - static MemRefType getChecked(Location location, ArrayRef shape, - Type elementType, + /// map composition, and memory space. If the MemRefType defined by the + /// arguments would be ill-formed, an error is emitted to `emitError` and a + /// null type is returned. + static MemRefType getChecked(function_ref emitError, + ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace); @@ -386,11 +387,11 @@ private: /// Get or create a new MemRefType defined by the arguments. If the resulting - /// type would be ill-formed, return nullptr. If the location is provided, - /// emit detailed error messages. + /// type would be ill-formed, return nullptr.
static MemRefType getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, - unsigned memorySpace, Optional location); + unsigned memorySpace, + function_ref emitError); using Base::getImpl; }; @@ -404,22 +405,23 @@ detail::UnrankedMemRefTypeStorage> { public: using Base::Base; + using Base::getChecked; /// Get or create a new UnrankedMemRefType of the provided element /// type and memory space static UnrankedMemRefType get(Type elementType, unsigned memorySpace); /// Get or create a new UnrankedMemRefType of the provided element - /// type and memory space declared at the given, potentially unknown, - /// location. If the UnrankedMemRefType defined by the arguments would be - /// ill-formed, emit errors and return a nullptr-wrapping type. - static UnrankedMemRefType getChecked(Location location, Type elementType, - unsigned memorySpace); + /// type and memory space. If the UnrankedMemRefType defined by the arguments + /// would be ill-formed, an error is emitted to `emitError` and a null type is + /// returned. + static UnrankedMemRefType + getChecked(function_ref emitError, Type elementType, + unsigned memorySpace); /// Verify the construction of a unranked memref type. 
- static LogicalResult verifyConstructionInvariants(Location loc, - Type elementType, - unsigned memorySpace); + static LogicalResult verify(function_ref emitError, + Type elementType, unsigned memorySpace); ArrayRef getShape() const { return llvm::None; } }; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -52,13 +52,11 @@ let parameters = (ins "Type":$elementType); let builders = [ TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ - return Base::get(elementType.getContext(), elementType); - }], [{ - return Base::getChecked($_loc, elementType); + return $_get(elementType.getContext(), elementType); }]> ]; let skipDefaultBuilders = 1; - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// @@ -137,7 +135,7 @@ let parameters = (ins "ArrayRef":$inputs, "ArrayRef":$results); let builders = [ TypeBuilder<(ins CArg<"TypeRange">:$inputs, CArg<"TypeRange">:$results), [{ - return Base::get($_ctxt, inputs, results); + return $_get($_ctxt, inputs, results); }]> ]; let skipDefaultBuilders = 1; @@ -225,7 +223,7 @@ // memory. let genStorageClass = 0; let skipDefaultBuilders = 1; - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; let extraClassDeclaration = [{ /// Signedness semantics. 
enum SignednessSemantics : uint32_t { @@ -295,7 +293,7 @@ "Identifier":$dialectNamespace, StringRefParameter<"">:$typeData ); - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// @@ -334,10 +332,10 @@ let parameters = (ins "ArrayRef":$types); let builders = [ TypeBuilder<(ins "TypeRange":$elementTypes), [{ - return Base::get($_ctxt, elementTypes); + return $_get($_ctxt, elementTypes); }]>, TypeBuilder<(ins), [{ - return Base::get($_ctxt, TypeRange()); + return $_get($_ctxt, TypeRange()); }]> ]; let skipDefaultBuilders = 1; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2492,15 +2492,21 @@ // // If an empty string is passed in for `body`, then *only* the builder // declaration will be generated; this provides a way to define complicated -// builders entirely in C++. +// builders entirely in C++. If a `body` string is provided, the `Base::get` +// method should be invoked using `$_get`, e.g.: +// +// ``` +// TypeBuilder<(ins "int":$integerArg, CArg<"float", "3.0f">:$floatArg), [{ +// return $_get($_ctxt, integerArg, floatArg); +// }]> +// ``` +// +// This is necessary because the `body` is also used to generate `getChecked` +// methods, which have a different underlying `Base::get*` call. // -// `checkedBody` is similar to `body`, but is the code block used when -// generating a `getChecked` method. -class TypeBuilder { +class TypeBuilder { dag dagParams = parameters; code body = bodyCode; - code checkedBody = checkedBodyCode; // The context parameter can be inferred from one of the other parameters and // is not implicitly added to the parameter list. @@ -2510,10 +2516,8 @@ // A class of TypeBuilder that is able to infer the MLIRContext parameter from // one of the other builder parameters. 
Instances of this builder do not have // `MLIRContext *` implicitly added to the parameter list. -class TypeBuilderWithInferredContext +class TypeBuilderWithInferredContext : TypeBuilder { - code checkedBody = checkedBodyCode; let hasInferredContextParam = 1; } @@ -2590,9 +2594,8 @@ // Avoid generating default get/getChecked functions. Custom get methods must // be provided. bit skipDefaultBuilders = 0; - // Generate the verifyConstructionInvariants declaration and getChecked - // method. - bit genVerifyInvariantsDecl = 0; + // Generate the verify declaration and getChecked methods. + bit genVerifyDecl = 0; // Extra code to include in the class declaration. code extraClassDeclaration = [{}]; diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -17,16 +17,21 @@ #include "mlir/Support/LogicalResult.h" #include "mlir/Support/StorageUniquer.h" #include "mlir/Support/TypeID.h" +#include "llvm/ADT/FunctionExtras.h" namespace mlir { -class AttributeStorage; +class InFlightDiagnostic; +class Location; class MLIRContext; namespace detail { -/// Utility method to generate a raw default location for use when checking the -/// construction invariants of a storage object. This is defined out-of-line to -/// avoid the need to include Location.h. -const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx); +/// Utility method to generate a callback that can be used to generate a +/// diagnostic when checking the construction invariants of a storage object. +/// This is defined out-of-line to avoid the need to include Location.h.
+llvm::unique_function +getDefaultDiagnosticEmitFn(MLIRContext *ctx); +llvm::unique_function +getDefaultDiagnosticEmitFn(const Location &loc); //===----------------------------------------------------------------------===// // StorageUserTraitBase @@ -88,20 +93,30 @@ template static ConcreteT get(MLIRContext *ctx, Args... args) { // Ensure that the invariants are correct for construction. - assert(succeeded(ConcreteT::verifyConstructionInvariants( - generateUnknownStorageLocation(ctx), args...))); + assert( + succeeded(ConcreteT::verify(getDefaultDiagnosticEmitFn(ctx), args...))); return UniquerT::template get(ctx, args...); } /// Get or create a new ConcreteT instance within the ctx, defined at /// the given, potentially unknown, location. If the arguments provided are - /// invalid then emit errors and return a null object. - template - static ConcreteT getChecked(LocationT loc, Args... args) { + /// invalid, errors are emitted using the provided location and a null object + /// is returned. + template + static ConcreteT getChecked(const Location &loc, Args... args) { + return ConcreteT::getChecked(getDefaultDiagnosticEmitFn(loc), args...); + } + + /// Get or create a new ConcreteT instance within the ctx. If the arguments + /// provided are invalid, errors are emitted using the provided `emitError` + /// and a null object is returned. + template + static ConcreteT getChecked(function_ref emitErrorFn, + MLIRContext *ctx, Args... args) { // If the construction invariants fail then we return a null attribute. - if (failed(ConcreteT::verifyConstructionInvariants(loc, args...))) + if (failed(ConcreteT::verify(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get(loc.getContext(), args...); + return UniquerT::template get(ctx, args...); } /// Get an instance of the concrete type from a void pointer. @@ -120,7 +135,7 @@ /// Default implementation that just returns success. template - static LogicalResult verifyConstructionInvariants(Args... 
args) { + static LogicalResult verify(Args... args) { return success(); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -32,8 +32,9 @@ /// Derived type classes are expected to implement several required /// implementation hooks: /// * Optional: -/// - static LogicalResult verifyConstructionInvariants(Location loc, -/// Args... args) +/// - static LogicalResult verify( +/// function_ref emitError, +/// Args... args) /// * This method is invoked when calling the 'TypeBase::get/getChecked' /// methods to ensure that the arguments passed in are valid to construct /// a type instance with. diff --git a/mlir/include/mlir/TableGen/TypeDef.h b/mlir/include/mlir/TableGen/TypeDef.h --- a/mlir/include/mlir/TableGen/TypeDef.h +++ b/mlir/include/mlir/TableGen/TypeDef.h @@ -36,10 +36,6 @@ public: using Builder::Builder; - /// Return an optional code body used for the `getChecked` variant of this - /// builder. - Optional getCheckedBody() const; - /// Returns true if this builder is able to infer the MLIRContext parameter. bool hasInferredContextParameter() const; }; @@ -106,9 +102,9 @@ // generated. bool genAccessors() const; - // Return true if we need to generate the verifyConstructionInvariants - // declaration and getChecked method. - bool genVerifyInvariantsDecl() const; + // Return true if we need to generate the verify declaration and getChecked + // method. + bool genVerifyDecl() const; // Returns the dialects extra class declaration code. 
Optional getExtraDecls() const; diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -1868,7 +1868,7 @@ c.def_static( "get", [](PyType &type, double value, DefaultingPyLocation loc) { - MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc); + MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirAttributeIsNull(attr)) { @@ -2765,8 +2765,8 @@ "get", [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { - MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(), - elementType, loc); + MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2797,7 +2797,7 @@ [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { MlirType t = mlirRankedTensorTypeGetChecked( - shape.size(), shape.data(), elementType, loc); + loc, shape.size(), shape.data(), elementType); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2828,7 +2828,7 @@ c.def_static( "get", [](PyType &elementType, DefaultingPyLocation loc) { - MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -2869,9 +2869,9 @@ for (PyAffineMap &map : layout) maps.push_back(map); - MlirType t = mlirMemRefTypeGetChecked(elementType, shape.size(), + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), maps.size(), - maps.data(), memorySpace, loc); + maps.data(), memorySpace); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2948,7 +2948,7 @@ [](PyType &elementType, unsigned memorySpace, DefaultingPyLocation loc) { MlirType t = - mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc); + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -103,9 +103,9 @@ return wrap(FloatAttr::get(unwrap(type), value)); } -MlirAttribute mlirFloatAttrDoubleGetChecked(MlirType type, double value, - MlirLocation loc) { - return wrap(FloatAttr::getChecked(unwrap(type), value, unwrap(loc))); +MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, + double value) { + return wrap(FloatAttr::getChecked(unwrap(loc), unwrap(type), value)); } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -169,8 +169,8 @@ unwrap(elementType))); } -MlirType mlirVectorTypeGetChecked(intptr_t rank, const int64_t *shape, - MlirType elementType, MlirLocation loc) { +MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, MlirType elementType) { return wrap(VectorType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType))); @@ -197,9 
+197,9 @@ unwrap(elementType))); } -MlirType mlirRankedTensorTypeGetChecked(intptr_t rank, const int64_t *shape, - MlirType elementType, - MlirLocation loc) { +MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, + MlirType elementType) { return wrap(RankedTensorType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType))); @@ -209,8 +209,8 @@ return wrap(UnrankedTensorType::get(unwrap(elementType))); } -MlirType mlirUnrankedTensorTypeGetChecked(MlirType elementType, - MlirLocation loc) { +MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, + MlirType elementType) { return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } @@ -231,10 +231,11 @@ unwrap(elementType), maps, memorySpace)); } -MlirType mlirMemRefTypeGetChecked(MlirType elementType, intptr_t rank, - const int64_t *shape, intptr_t numMaps, +MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, + intptr_t rank, const int64_t *shape, + intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace, MlirLocation loc) { + unsigned memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( @@ -250,10 +251,10 @@ unwrap(elementType), llvm::None, memorySpace)); } -MlirType mlirMemRefTypeContiguousGetChecked(MlirType elementType, intptr_t rank, +MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, + MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace, - MlirLocation loc) { + unsigned memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), llvm::None, memorySpace)); @@ -280,9 +281,9 @@ return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); } -MlirType mlirUnrankedMemRefTypeGetChecked(MlirType elementType, - unsigned memorySpace, - MlirLocation loc) { +MlirType 
mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, + MlirType elementType, + unsigned memorySpace) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), memorySpace)); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -187,7 +187,7 @@ // Function type without arguments. if (succeeded(parser.parseOptionalRParen())) { if (succeeded(parser.parseGreater())) - return LLVMFunctionType::getChecked(loc, returnType, {}, + return LLVMFunctionType::getChecked(loc, returnType, llvm::None, /*isVarArg=*/false); return LLVMFunctionType(); } @@ -345,7 +345,8 @@ if (knownStructNames.count(name)) { if (failed(parser.parseGreater())) return LLVMStructType(); - return LLVMStructType::getIdentifiedChecked(loc, name); + return LLVMStructType::getIdentifiedChecked( + [loc] { return emitError(loc); }, loc.getContext(), name); } if (failed(parser.parseComma())) return LLVMStructType(); @@ -359,7 +360,8 @@ LLVMStructType(); if (failed(parser.parseGreater())) return LLVMStructType(); - auto type = LLVMStructType::getOpaqueChecked(loc, name); + auto type = LLVMStructType::getOpaqueChecked( + [loc] { return emitError(loc); }, loc.getContext(), name); if (!type.isOpaque()) { parser.emitError(kwLoc, "redeclaring defined struct as opaque"); return LLVMStructType(); @@ -377,8 +379,10 @@ if (failed(parser.parseGreater())) return LLVMStructType(); if (!isIdentified) - return LLVMStructType::getLiteralChecked(loc, {}, isPacked); - auto type = LLVMStructType::getIdentifiedChecked(loc, name); + return LLVMStructType::getLiteralChecked([loc] { return emitError(loc); }, + loc.getContext(), {}, isPacked); + auto type = LLVMStructType::getIdentifiedChecked( + [loc] { return emitError(loc); }, loc.getContext(), name); return trySetStructBody(type, {}, isPacked, parser, kwLoc); } @@ -402,8 +406,10 @@ // Construct the 
struct with body. if (!isIdentified) - return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked); - auto type = LLVMStructType::getIdentifiedChecked(loc, name); + return LLVMStructType::getLiteralChecked( + [loc] { return emitError(loc); }, loc.getContext(), subtypes, isPacked); + auto type = LLVMStructType::getIdentifiedChecked( + [loc] { return emitError(loc); }, loc.getContext(), name); return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -39,10 +39,12 @@ return Base::get(elementType.getContext(), elementType, numElements); } -LLVMArrayType LLVMArrayType::getChecked(Location loc, Type elementType, - unsigned numElements) { +LLVMArrayType +LLVMArrayType::getChecked(function_ref emitError, + Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, elementType, numElements); + return Base::getChecked(emitError, elementType.getContext(), elementType, + numElements); } Type LLVMArrayType::getElementType() { return getImpl()->elementType; } @@ -50,10 +52,10 @@ unsigned LLVMArrayType::getNumElements() { return getImpl()->numElements; } LogicalResult -LLVMArrayType::verifyConstructionInvariants(Location loc, Type elementType, - unsigned numElements) { +LLVMArrayType::verify(function_ref emitError, + Type elementType, unsigned numElements) { if (!isValidElementType(elementType)) - return emitError(loc, "invalid array element type: ") << elementType; + return emitError() << "invalid array element type: " << elementType; return success(); } @@ -75,11 +77,13 @@ return Base::get(result.getContext(), result, arguments, isVarArg); } -LLVMFunctionType LLVMFunctionType::getChecked(Location loc, Type result, - ArrayRef arguments, - bool isVarArg) { +LLVMFunctionType 
+LLVMFunctionType::getChecked(function_ref emitError, + Type result, ArrayRef arguments, + bool isVarArg) { assert(result && "expected non-null result"); - return Base::getChecked(loc, result, arguments, isVarArg); + return Base::getChecked(emitError, result.getContext(), result, arguments, + isVarArg); } Type LLVMFunctionType::getReturnType() { return getImpl()->getReturnType(); } @@ -99,14 +103,14 @@ } LogicalResult -LLVMFunctionType::verifyConstructionInvariants(Location loc, Type result, - ArrayRef arguments, bool) { +LLVMFunctionType::verify(function_ref emitError, + Type result, ArrayRef arguments, bool) { if (!isValidResultType(result)) - return emitError(loc, "invalid function result type: ") << result; + return emitError() << "invalid function result type: " << result; for (Type arg : arguments) if (!isValidArgumentType(arg)) - return emitError(loc, "invalid function argument type: ") << arg; + return emitError() << "invalid function argument type: " << arg; return success(); } @@ -125,20 +129,22 @@ return Base::get(pointee.getContext(), pointee, addressSpace); } -LLVMPointerType LLVMPointerType::getChecked(Location loc, Type pointee, - unsigned addressSpace) { - return Base::getChecked(loc, pointee, addressSpace); +LLVMPointerType +LLVMPointerType::getChecked(function_ref emitError, + Type pointee, unsigned addressSpace) { + return Base::getChecked(emitError, pointee.getContext(), pointee, + addressSpace); } Type LLVMPointerType::getElementType() { return getImpl()->pointeeType; } unsigned LLVMPointerType::getAddressSpace() { return getImpl()->addressSpace; } -LogicalResult LLVMPointerType::verifyConstructionInvariants(Location loc, - Type pointee, - unsigned) { +LogicalResult +LLVMPointerType::verify(function_ref emitError, + Type pointee, unsigned) { if (!isValidElementType(pointee)) - return emitError(loc, "invalid pointer element type: ") << pointee; + return emitError() << "invalid pointer element type: " << pointee; return success(); } @@ -156,9 
+162,10 @@ return Base::get(context, name, /*opaque=*/false); } -LLVMStructType LLVMStructType::getIdentifiedChecked(Location loc, - StringRef name) { - return Base::getChecked(loc, name, /*opaque=*/false); +LLVMStructType LLVMStructType::getIdentifiedChecked( + function_ref emitError, MLIRContext *context, + StringRef name) { + return Base::getChecked(emitError, context, name, /*opaque=*/false); } LLVMStructType LLVMStructType::getNewIdentified(MLIRContext *context, @@ -183,18 +190,21 @@ return Base::get(context, types, isPacked); } -LLVMStructType LLVMStructType::getLiteralChecked(Location loc, - ArrayRef types, - bool isPacked) { - return Base::getChecked(loc, types, isPacked); +LLVMStructType +LLVMStructType::getLiteralChecked(function_ref emitError, + MLIRContext *context, ArrayRef types, + bool isPacked) { + return Base::getChecked(emitError, context, types, isPacked); } LLVMStructType LLVMStructType::getOpaque(StringRef name, MLIRContext *context) { return Base::get(context, name, /*opaque=*/true); } -LLVMStructType LLVMStructType::getOpaqueChecked(Location loc, StringRef name) { - return Base::getChecked(loc, name, /*opaque=*/true); +LLVMStructType +LLVMStructType::getOpaqueChecked(function_ref emitError, + MLIRContext *context, StringRef name) { + return Base::getChecked(emitError, context, name, /*opaque=*/true); } LogicalResult LLVMStructType::setBody(ArrayRef types, bool isPacked) { @@ -217,17 +227,17 @@ : getImpl()->getTypeList(); } -LogicalResult LLVMStructType::verifyConstructionInvariants(Location, StringRef, - bool) { +LogicalResult LLVMStructType::verify(function_ref, + StringRef, bool) { return success(); } -LogicalResult LLVMStructType::verifyConstructionInvariants(Location loc, - ArrayRef types, - bool) { +LogicalResult +LLVMStructType::verify(function_ref emitError, + ArrayRef types, bool) { for (Type t : types) if (!isValidElementType(t)) - return emitError(loc, "invalid LLVM structure element type: ") << t; + return emitError() << "invalid 
LLVM structure element type: " << t; return success(); } @@ -238,14 +248,14 @@ /// Verifies that the type about to be constructed is well-formed. template -static LogicalResult verifyVectorConstructionInvariants(Location loc, - Type elementType, - unsigned numElements) { +static LogicalResult +verifyVectorConstructionInvariants(function_ref emitError, + Type elementType, unsigned numElements) { if (numElements == 0) - return emitError(loc, "the number of vector elements must be positive"); + return emitError() << "the number of vector elements must be positive"; if (!VecTy::isValidElementType(elementType)) - return emitError(loc, "invalid vector element type"); + return emitError() << "invalid vector element type"; return success(); } @@ -256,11 +266,12 @@ return Base::get(elementType.getContext(), elementType, numElements); } -LLVMFixedVectorType LLVMFixedVectorType::getChecked(Location loc, - Type elementType, - unsigned numElements) { +LLVMFixedVectorType +LLVMFixedVectorType::getChecked(function_ref emitError, + Type elementType, unsigned numElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, elementType, numElements); + return Base::getChecked(emitError, elementType.getContext(), elementType, + numElements); } Type LLVMFixedVectorType::getElementType() { @@ -275,10 +286,11 @@ return type.isa(); } -LogicalResult LLVMFixedVectorType::verifyConstructionInvariants( - Location loc, Type elementType, unsigned numElements) { +LogicalResult +LLVMFixedVectorType::verify(function_ref emitError, + Type elementType, unsigned numElements) { return verifyVectorConstructionInvariants( - loc, elementType, numElements); + emitError, elementType, numElements); } //===----------------------------------------------------------------------===// @@ -292,10 +304,11 @@ } LLVMScalableVectorType -LLVMScalableVectorType::getChecked(Location loc, Type elementType, - unsigned minNumElements) { +LLVMScalableVectorType::getChecked(function_ref 
emitError, + Type elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); - return Base::getChecked(loc, elementType, minNumElements); + return Base::getChecked(emitError, elementType.getContext(), elementType, + minNumElements); } Type LLVMScalableVectorType::getElementType() { @@ -313,10 +326,11 @@ return isCompatibleFloatingPointType(type) || type.isa(); } -LogicalResult LLVMScalableVectorType::verifyConstructionInvariants( - Location loc, Type elementType, unsigned numElements) { +LogicalResult +LLVMScalableVectorType::verify(function_ref emitError, + Type elementType, unsigned numElements) { return verifyVectorConstructionInvariants( - loc, elementType, numElements); + emitError, elementType, numElements); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -28,20 +28,21 @@ return llvm::isa(type.getDialect()); } -LogicalResult QuantizedType::verifyConstructionInvariants( - Location loc, unsigned flags, Type storageType, Type expressedType, - int64_t storageTypeMin, int64_t storageTypeMax) { +LogicalResult +QuantizedType::verify(function_ref emitError, + unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) { // Verify that the storage type is integral. // This restriction may be lifted at some point in favor of using bf16 // or f16 as exact representations on hardware where that is advantageous. auto intStorageType = storageType.dyn_cast(); if (!intStorageType) - return emitError(loc, "storage type must be integral"); + return emitError() << "storage type must be integral"; unsigned integralWidth = intStorageType.getWidth(); // Verify storage width. 
if (integralWidth == 0 || integralWidth > MaxStorageBits) - return emitError(loc, "illegal storage type size: ") << integralWidth; + return emitError() << "illegal storage type size: " << integralWidth; // Verify storageTypeMin and storageTypeMax. bool isSigned = @@ -53,8 +54,8 @@ if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultIntegerMin || storageTypeMax > defaultIntegerMax) { - return emitError(loc, "illegal storage min and storage max: (") - << storageTypeMin << ":" << storageTypeMax << ")"; + return emitError() << "illegal storage min and storage max: (" + << storageTypeMin << ":" << storageTypeMax << ")"; } return success(); } @@ -208,21 +209,22 @@ storageTypeMin, storageTypeMax); } -AnyQuantizedType AnyQuantizedType::getChecked(unsigned flags, Type storageType, - Type expressedType, - int64_t storageTypeMin, - int64_t storageTypeMax, - Location location) { - return Base::getChecked(location, flags, storageType, expressedType, - storageTypeMin, storageTypeMax); +AnyQuantizedType +AnyQuantizedType::getChecked(function_ref emitError, + unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, storageTypeMin, + storageTypeMax); } -LogicalResult AnyQuantizedType::verifyConstructionInvariants( - Location loc, unsigned flags, Type storageType, Type expressedType, - int64_t storageTypeMin, int64_t storageTypeMax) { - if (failed(QuantizedType::verifyConstructionInvariants( - loc, flags, storageType, expressedType, storageTypeMin, - storageTypeMax))) { +LogicalResult +AnyQuantizedType::verify(function_ref emitError, + unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, + storageTypeMin, storageTypeMax))) { return failure(); } @@ -230,7 +232,7 @@ // If this restriction 
is ever eliminated, the parser/printer must be // extended. if (expressedType && !expressedType.isa()) - return emitError(loc, "expressed type must be floating point"); + return emitError() << "expressed type must be floating point"; return success(); } @@ -244,39 +246,38 @@ scale, zeroPoint, storageTypeMin, storageTypeMax); } -UniformQuantizedType -UniformQuantizedType::getChecked(unsigned flags, Type storageType, - Type expressedType, double scale, - int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax, Location location) { - return Base::getChecked(location, flags, storageType, expressedType, scale, - zeroPoint, storageTypeMin, storageTypeMax); +UniformQuantizedType UniformQuantizedType::getChecked( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax); } -LogicalResult UniformQuantizedType::verifyConstructionInvariants( - Location loc, unsigned flags, Type storageType, Type expressedType, - double scale, int64_t zeroPoint, int64_t storageTypeMin, - int64_t storageTypeMax) { - if (failed(QuantizedType::verifyConstructionInvariants( - loc, flags, storageType, expressedType, storageTypeMin, - storageTypeMax))) { +LogicalResult UniformQuantizedType::verify( + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, + storageTypeMin, storageTypeMax))) { return failure(); } // Uniform quantization requires fully expressed parameters, including // expressed type. 
if (!expressedType) - return emitError(loc, "uniform quantization requires expressed type"); + return emitError() << "uniform quantization requires expressed type"; // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. if (!expressedType.isa()) - return emitError(loc, "expressed type must be floating point"); + return emitError() << "expressed type must be floating point"; // Verify scale. if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) - return emitError(loc, "illegal scale: ") << scale; + return emitError() << "illegal scale: " << scale; return success(); } @@ -298,46 +299,45 @@ } UniformQuantizedPerAxisType UniformQuantizedPerAxisType::getChecked( - unsigned flags, Type storageType, Type expressedType, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax, - Location location) { - return Base::getChecked(location, flags, storageType, expressedType, scales, - zeroPoints, quantizedDimension, storageTypeMin, - storageTypeMax); + function_ref emitError, unsigned flags, + Type storageType, Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + return Base::getChecked(emitError, storageType.getContext(), flags, + storageType, expressedType, scales, zeroPoints, + quantizedDimension, storageTypeMin, storageTypeMax); } -LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants( - Location loc, unsigned flags, Type storageType, Type expressedType, - ArrayRef scales, ArrayRef zeroPoints, - int32_t quantizedDimension, int64_t storageTypeMin, - int64_t storageTypeMax) { - if (failed(QuantizedType::verifyConstructionInvariants( - loc, flags, storageType, expressedType, storageTypeMin, - storageTypeMax))) { +LogicalResult UniformQuantizedPerAxisType::verify( + function_ref emitError, unsigned flags, + Type storageType, Type 
expressedType, ArrayRef scales, + ArrayRef zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (failed(QuantizedType::verify(emitError, flags, storageType, expressedType, + storageTypeMin, storageTypeMax))) { return failure(); } // Uniform quantization requires fully expressed parameters, including // expressed type. if (!expressedType) - return emitError(loc, "uniform quantization requires expressed type"); + return emitError() << "uniform quantization requires expressed type"; // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. if (!expressedType.isa()) - return emitError(loc, "expressed type must be floating point"); + return emitError() << "expressed type must be floating point"; // Ensure that the number of scales and zeroPoints match. if (scales.size() != zeroPoints.size()) - return emitError(loc, "illegal number of scales and zeroPoints: ") - << scales.size() << ", " << zeroPoints.size(); + return emitError() << "illegal number of scales and zeroPoints: " + << scales.size() << ", " << zeroPoints.size(); // Verify scale. 
for (double scale : scales) { if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) - return emitError(loc, "illegal scale: ") << scale; + return emitError() << "illegal scale: " << scale; } return success(); @@ -360,22 +360,23 @@ return Base::get(expressedType.getContext(), expressedType, min, max); } -CalibratedQuantizedType CalibratedQuantizedType::getChecked(Type expressedType, - double min, - double max, - Location location) { - return Base::getChecked(location, expressedType, min, max); +CalibratedQuantizedType CalibratedQuantizedType::getChecked( + function_ref emitError, Type expressedType, + double min, double max) { + return Base::getChecked(emitError, expressedType.getContext(), expressedType, + min, max); } -LogicalResult CalibratedQuantizedType::verifyConstructionInvariants( - Location loc, Type expressedType, double min, double max) { +LogicalResult +CalibratedQuantizedType::verify(function_ref emitError, + Type expressedType, double min, double max) { // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. 
if (!expressedType.isa()) - return emitError(loc, "expressed type must be floating point"); + return emitError() << "expressed type must be floating point"; if (max <= min) - return emitError(loc, "illegal min and max: (") << min << ":" << max << ")"; + return emitError() << "illegal min and max: (" << min << ":" << max << ")"; return success(); } diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -155,8 +155,9 @@ return nullptr; } - return AnyQuantizedType::getChecked(typeFlags, storageType, expressedType, - storageTypeMin, storageTypeMax, loc); + return AnyQuantizedType::getChecked(loc, typeFlags, storageType, + expressedType, storageTypeMin, + storageTypeMax); } static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, @@ -279,13 +280,13 @@ ArrayRef scalesRef(scales.begin(), scales.end()); ArrayRef zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); return UniformQuantizedPerAxisType::getChecked( - typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, - quantizedDimension, storageTypeMin, storageTypeMax, loc); + loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, + quantizedDimension, storageTypeMin, storageTypeMax); } - return UniformQuantizedType::getChecked(typeFlags, storageType, expressedType, - scales.front(), zeroPoints.front(), - storageTypeMin, storageTypeMax, loc); + return UniformQuantizedType::getChecked( + loc, typeFlags, storageType, expressedType, scales.front(), + zeroPoints.front(), storageTypeMin, storageTypeMax); } /// Parses an CalibratedQuantizedType. @@ -313,7 +314,7 @@ return nullptr; } - return CalibratedQuantizedType::getChecked(expressedType, min, max, loc); + return CalibratedQuantizedType::getChecked(loc, expressedType, min, max); } /// Parse a type registered to this dialect. 
diff --git a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp --- a/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp +++ b/mlir/lib/Dialect/Quant/Utils/FakeQuantSupport.cpp @@ -123,17 +123,17 @@ // 0.0s, so the scale is set to 1.0 and the tensor can be quantized to zero // points and dequantized to 0.0. if (std::fabs(rmax - rmin) < std::numeric_limits::epsilon()) { - return UniformQuantizedType::getChecked(flags, storageType, expressedType, - 1.0, qmin, qmin, qmax, loc); + return UniformQuantizedType::getChecked( + loc, flags, storageType, expressedType, 1.0, qmin, qmin, qmax); } double scale; int64_t nudgedZeroPoint; getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint); - return UniformQuantizedType::getChecked(flags, storageType, expressedType, - scale, nudgedZeroPoint, qmin, qmax, - loc); + return UniformQuantizedType::getChecked(loc, flags, storageType, + expressedType, scale, nudgedZeroPoint, + qmin, qmax); } UniformQuantizedPerAxisType mlir::quant::fakeQuantAttrsToType( @@ -179,6 +179,6 @@ unsigned flags = isSigned ? 
QuantizationFlags::Signed : 0; return UniformQuantizedPerAxisType::getChecked( - flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, - qmin, qmax, loc); + loc, flags, storageType, expressedType, scales, zeroPoints, + quantizedDimension, qmin, qmax); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVAttributes.cpp @@ -162,23 +162,23 @@ return llvm::None; } -LogicalResult spirv::InterfaceVarABIAttr::verifyConstructionInvariants( - Location loc, IntegerAttr descriptorSet, IntegerAttr binding, - IntegerAttr storageClass) { +LogicalResult spirv::InterfaceVarABIAttr::verify( + function_ref emitError, IntegerAttr descriptorSet, + IntegerAttr binding, IntegerAttr storageClass) { if (!descriptorSet.getType().isSignlessInteger(32)) - return emitError(loc, "expected 32-bit integer for descriptor set"); + return emitError() << "expected 32-bit integer for descriptor set"; if (!binding.getType().isSignlessInteger(32)) - return emitError(loc, "expected 32-bit integer for binding"); + return emitError() << "expected 32-bit integer for binding"; if (storageClass) { if (auto storageClassAttr = storageClass.cast()) { auto storageClassValue = spirv::symbolizeStorageClass(storageClassAttr.getInt()); if (!storageClassValue) - return emitError(loc, "unknown storage class"); + return emitError() << "unknown storage class"; } else { - return emitError(loc, "expected valid storage class"); + return emitError() << "expected valid storage class"; } } @@ -257,11 +257,12 @@ return getImpl()->capabilities.cast(); } -LogicalResult spirv::VerCapExtAttr::verifyConstructionInvariants( - Location loc, IntegerAttr version, ArrayAttr capabilities, - ArrayAttr extensions) { +LogicalResult +spirv::VerCapExtAttr::verify(function_ref emitError, + IntegerAttr version, ArrayAttr capabilities, + ArrayAttr extensions) { if 
(!version.getType().isSignlessInteger(32)) - return emitError(loc, "expected 32-bit integer for version"); + return emitError() << "expected 32-bit integer for version"; if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { if (auto intAttr = attr.dyn_cast()) @@ -269,7 +270,7 @@ return true; return false; })) - return emitError(loc, "unknown capability in capability list"); + return emitError() << "unknown capability in capability list"; if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { if (auto strAttr = attr.dyn_cast()) @@ -277,7 +278,7 @@ return true; return false; })) - return emitError(loc, "unknown extension in extension list"); + return emitError() << "unknown extension in extension list"; return success(); } @@ -338,12 +339,14 @@ return getImpl()->limits.cast(); } -LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants( - Location loc, spirv::VerCapExtAttr /*triple*/, spirv::Vendor /*vendorID*/, - spirv::DeviceType /*deviceType*/, uint32_t /*deviceID*/, - DictionaryAttr limits) { +LogicalResult +spirv::TargetEnvAttr::verify(function_ref emitError, + spirv::VerCapExtAttr /*triple*/, + spirv::Vendor /*vendorID*/, + spirv::DeviceType /*deviceType*/, + uint32_t /*deviceID*/, DictionaryAttr limits) { if (!limits.isa()) - return emitError(loc, "expected spirv::ResourceLimitsAttr for limits"); + return emitError() << "expected spirv::ResourceLimitsAttr for limits"; return success(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -730,17 +730,19 @@ return Base::get(imageType.getContext(), imageType); } -SampledImageType SampledImageType::getChecked(Type imageType, - Location location) { - return Base::getChecked(location, imageType); +SampledImageType +SampledImageType::getChecked(function_ref emitError, + Type imageType) { + return Base::getChecked(emitError, imageType.getContext(), 
imageType); } Type SampledImageType::getImageType() const { return getImpl()->imageType; } -LogicalResult SampledImageType::verifyConstructionInvariants(Location loc, - Type imageType) { +LogicalResult +SampledImageType::verify(function_ref emitError, + Type imageType) { if (!imageType.isa()) - return emitError(loc, "expected image type"); + return emitError() << "expected image type"; return success(); } @@ -1095,27 +1097,27 @@ return Base::get(columnType.getContext(), columnType, columnCount); } -MatrixType MatrixType::getChecked(Type columnType, uint32_t columnCount, - Location location) { - return Base::getChecked(location, columnType, columnCount); +MatrixType MatrixType::getChecked(function_ref emitError, + Type columnType, uint32_t columnCount) { + return Base::getChecked(emitError, columnType.getContext(), columnType, + columnCount); } -LogicalResult MatrixType::verifyConstructionInvariants(Location loc, - Type columnType, - uint32_t columnCount) { +LogicalResult MatrixType::verify(function_ref emitError, + Type columnType, uint32_t columnCount) { if (columnCount < 2 || columnCount > 4) - return emitError(loc, "matrix can have 2, 3, or 4 columns only"); + return emitError() << "matrix can have 2, 3, or 4 columns only"; if (!isValidColumnType(columnType)) - return emitError(loc, "matrix columns must be vectors of floats"); + return emitError() << "matrix columns must be vectors of floats"; /// The underlying vectors (columns) must be of size 2, 3, or 4 ArrayRef columnShape = columnType.cast().getShape(); if (columnShape.size() != 1) - return emitError(loc, "matrix columns must be 1D vectors"); + return emitError() << "matrix columns must be 1D vectors"; if (columnShape[0] < 2 || columnShape[0] > 4) - return emitError(loc, "matrix columns must be of size 2, 3, or 4"); + return emitError() << "matrix columns must be of size 2, 3, or 4"; return success(); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- 
a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -211,16 +211,18 @@ return Base::get(type.getContext(), type, value); } -FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { - return Base::getChecked(loc, type, value); +FloatAttr FloatAttr::getChecked(function_ref emitError, + Type type, double value) { + return Base::getChecked(emitError, type.getContext(), type, value); } FloatAttr FloatAttr::get(Type type, const APFloat &value) { return Base::get(type.getContext(), type, value); } -FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { - return Base::getChecked(loc, type, value); +FloatAttr FloatAttr::getChecked(function_ref emitError, + Type type, const APFloat &value) { + return Base::getChecked(emitError, type.getContext(), type, value); } APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } @@ -238,27 +240,29 @@ } /// Verify construction invariants. -static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { +static LogicalResult +verifyFloatTypeInvariants(function_ref emitError, + Type type) { if (!type.isa()) - return emitError(loc, "expected floating point type"); + return emitError() << "expected floating point type"; return success(); } -LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, - double value) { - return verifyFloatTypeInvariants(loc, type); +LogicalResult FloatAttr::verify(function_ref emitError, + Type type, double value) { + return verifyFloatTypeInvariants(emitError, type); } -LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, - const APFloat &value) { +LogicalResult FloatAttr::verify(function_ref emitError, + Type type, const APFloat &value) { // Verify that the type is correct. - if (failed(verifyFloatTypeInvariants(loc, type))) + if (failed(verifyFloatTypeInvariants(emitError, type))) return failure(); // Verify that the type semantics match that of the value. 
if (&type.cast().getFloatSemantics() != &value.getSemantics()) { - return emitError( - loc, "FloatAttr type doesn't match the type implied by its value"); + return emitError() + << "FloatAttr type doesn't match the type implied by its value"; } return success(); } @@ -326,26 +330,28 @@ return getValue().getZExtValue(); } -static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { +static LogicalResult +verifyIntegerTypeInvariants(function_ref emitError, + Type type) { if (type.isa()) return success(); - return emitError(loc, "expected integer or index type"); + return emitError() << "expected integer or index type"; } -LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, - int64_t value) { - return verifyIntegerTypeInvariants(loc, type); +LogicalResult IntegerAttr::verify(function_ref emitError, + Type type, int64_t value) { + return verifyIntegerTypeInvariants(emitError, type); } -LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, - const APInt &value) { - if (failed(verifyIntegerTypeInvariants(loc, type))) +LogicalResult IntegerAttr::verify(function_ref emitError, + Type type, const APInt &value) { + if (failed(verifyIntegerTypeInvariants(emitError, type))) return failure(); if (auto integerType = type.dyn_cast()) if (integerType.getWidth() != value.getBitWidth()) - return emitError(loc, "integer type bit width (") - << integerType.getWidth() << ") doesn't match value bit width (" - << value.getBitWidth() << ")"; + return emitError() << "integer type bit width (" << integerType.getWidth() + << ") doesn't match value bit width (" + << value.getBitWidth() << ")"; return success(); } @@ -381,9 +387,11 @@ return Base::get(context, dialect, attrData, type); } -OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, - Type type, Location location) { - return Base::getChecked(location, dialect, attrData, type); +OpaqueAttr OpaqueAttr::getChecked(function_ref emitError, + 
Identifier dialect, StringRef attrData, + Type type) { + return Base::getChecked(emitError, dialect.getContext(), dialect, attrData, + type); } /// Returns the dialect namespace of the opaque attribute. @@ -395,12 +403,11 @@ StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } /// Verify the construction of an opaque attribute. -LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, - Identifier dialect, - StringRef attrData, - Type type) { +LogicalResult OpaqueAttr::verify(function_ref emitError, + Identifier dialect, StringRef attrData, + Type type) { if (!Dialect::isValidNamespace(dialect.strref())) - return emitError(loc, "invalid dialect namespace '") << dialect << "'"; + return emitError() << "invalid dialect namespace '" << dialect << "'"; return success(); } diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -32,10 +32,10 @@ //===----------------------------------------------------------------------===// /// Verify the construction of an integer type. -LogicalResult ComplexType::verifyConstructionInvariants(Location loc, - Type elementType) { +LogicalResult ComplexType::verify(function_ref emitError, + Type elementType) { if (!elementType.isIntOrFloat()) - return emitError(loc, "invalid element type for complex"); + return emitError() << "invalid element type for complex"; return success(); } @@ -47,12 +47,12 @@ constexpr unsigned IntegerType::kMaxWidth; /// Verify the construction of an integer type. 
-LogicalResult -IntegerType::verifyConstructionInvariants(Location loc, unsigned width, - SignednessSemantics signedness) { +LogicalResult IntegerType::verify(function_ref emitError, + unsigned width, + SignednessSemantics signedness) { if (width > IntegerType::kMaxWidth) { - return emitError(loc) << "integer bitwidth is limited to " - << IntegerType::kMaxWidth << " bits"; + return emitError() << "integer bitwidth is limited to " + << IntegerType::kMaxWidth << " bits"; } return success(); } @@ -183,11 +183,10 @@ //===----------------------------------------------------------------------===// /// Verify the construction of an opaque type. -LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, - Identifier dialect, - StringRef typeData) { +LogicalResult OpaqueType::verify(function_ref emitError, + Identifier dialect, StringRef typeData) { if (!Dialect::isValidNamespace(dialect.strref())) - return emitError(loc, "invalid dialect namespace '") << dialect << "'"; + return emitError() << "invalid dialect namespace '" << dialect << "'"; return success(); } @@ -362,22 +361,22 @@ return Base::get(elementType.getContext(), shape, elementType); } -VectorType VectorType::getChecked(Location location, ArrayRef shape, - Type elementType) { - return Base::getChecked(location, shape, elementType); +VectorType VectorType::getChecked(function_ref emitError, + ArrayRef shape, Type elementType) { + return Base::getChecked(emitError, elementType.getContext(), shape, + elementType); } -LogicalResult VectorType::verifyConstructionInvariants(Location loc, - ArrayRef shape, - Type elementType) { +LogicalResult VectorType::verify(function_ref emitError, + ArrayRef shape, Type elementType) { if (shape.empty()) - return emitError(loc, "vector types must have at least one dimension"); + return emitError() << "vector types must have at least one dimension"; if (!isValidElementType(elementType)) - return emitError(loc, "vector elements must be int or float type"); + return 
emitError() << "vector elements must be int or float type"; if (any_of(shape, [](int64_t i) { return i <= 0; })) - return emitError(loc, "vector types must have positive constant sizes"); + return emitError() << "vector types must have positive constant sizes"; return success(); } @@ -400,12 +399,12 @@ // TensorType //===----------------------------------------------------------------------===// -// Check if "elementType" can be an element type of a tensor. Emit errors if -// location is not nullptr. Returns failure if check failed. -static LogicalResult checkTensorElementType(Location location, - Type elementType) { +// Check if "elementType" can be an element type of a tensor. +static LogicalResult +checkTensorElementType(function_ref emitError, + Type elementType) { if (!TensorType::isValidElementType(elementType)) - return emitError(location, "invalid tensor element type: ") << elementType; + return emitError() << "invalid tensor element type: " << elementType; return success(); } @@ -428,19 +427,21 @@ return Base::get(elementType.getContext(), shape, elementType); } -RankedTensorType RankedTensorType::getChecked(Location location, - ArrayRef shape, - Type elementType) { - return Base::getChecked(location, shape, elementType); +RankedTensorType +RankedTensorType::getChecked(function_ref emitError, + ArrayRef shape, Type elementType) { + return Base::getChecked(emitError, elementType.getContext(), shape, + elementType); } -LogicalResult RankedTensorType::verifyConstructionInvariants( - Location loc, ArrayRef shape, Type elementType) { +LogicalResult +RankedTensorType::verify(function_ref emitError, + ArrayRef shape, Type elementType) { for (int64_t s : shape) { if (s < -1) - return emitError(loc, "invalid tensor dimension size"); + return emitError() << "invalid tensor dimension size"; } - return checkTensorElementType(loc, elementType); + return checkTensorElementType(emitError, elementType); } ArrayRef RankedTensorType::getShape() const { @@ -455,15 +456,16 @@ 
return Base::get(elementType.getContext(), elementType); } -UnrankedTensorType UnrankedTensorType::getChecked(Location location, - Type elementType) { - return Base::getChecked(location, elementType); +UnrankedTensorType +UnrankedTensorType::getChecked(function_ref emitError, + Type elementType) { + return Base::getChecked(emitError, elementType.getContext(), elementType); } LogicalResult -UnrankedTensorType::verifyConstructionInvariants(Location loc, - Type elementType) { - return checkTensorElementType(loc, elementType); +UnrankedTensorType::verify(function_ref emitError, + Type elementType) { + return checkTensorElementType(emitError, elementType); } //===----------------------------------------------------------------------===// @@ -485,8 +487,10 @@ MemRefType MemRefType::get(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { - auto result = getImpl(shape, elementType, affineMapComposition, memorySpace, - /*location=*/llvm::None); + auto result = + getImpl(shape, elementType, affineMapComposition, memorySpace, [=] { + return emitError(UnknownLoc::get(elementType.getContext())); + }); assert(result && "Failed to construct instance of MemRefType."); return result; } @@ -497,12 +501,12 @@ /// UnknownLoc. If the MemRefType defined by the arguments would be /// ill-formed, emits errors (to the handler registered with the context or to /// the error stream) and returns nullptr. -MemRefType MemRefType::getChecked(Location location, ArrayRef shape, - Type elementType, +MemRefType MemRefType::getChecked(function_ref emitError, + ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace) { return getImpl(shape, elementType, affineMapComposition, memorySpace, - location); + emitError); } /// Get or create a new MemRefType defined by the arguments. 
If the resulting @@ -512,18 +516,16 @@ MemRefType MemRefType::getImpl(ArrayRef shape, Type elementType, ArrayRef affineMapComposition, unsigned memorySpace, - Optional location) { + function_ref emitError) { auto *context = elementType.getContext(); if (!BaseMemRefType::isValidElementType(elementType)) - return (void)emitOptionalError(location, "invalid memref element type"), - MemRefType(); + return (emitError() << "invalid memref element type", MemRefType()); for (int64_t s : shape) { // Negative sizes are not allowed except for `-1` that means dynamic size. if (s < -1) - return (void)emitOptionalError(location, "invalid memref size"), - MemRefType(); + return (emitError() << "invalid memref size", MemRefType()); } // Check that the structure of the composition is valid, i.e. that each @@ -533,12 +535,10 @@ unsigned i = 0; for (const auto &affineMap : affineMapComposition) { if (affineMap.getNumDims() != dim) { - if (location) - emitError(*location) - << "memref affine map dimension mismatch between " - << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i)) - << " and affine map" << i + 1 << ": " << dim - << " != " << affineMap.getNumDims(); + emitError() << "memref affine map dimension mismatch between " + << (i == 0 ? 
Twine("memref rank") : "affine map " + Twine(i)) + << " and affine map" << i + 1 << ": " << dim + << " != " << affineMap.getNumDims(); return nullptr; } @@ -575,17 +575,18 @@ return Base::get(elementType.getContext(), elementType, memorySpace); } -UnrankedMemRefType UnrankedMemRefType::getChecked(Location location, - Type elementType, - unsigned memorySpace) { - return Base::getChecked(location, elementType, memorySpace); +UnrankedMemRefType +UnrankedMemRefType::getChecked(function_ref emitError, + Type elementType, unsigned memorySpace) { + return Base::getChecked(emitError, elementType.getContext(), elementType, + memorySpace); } LogicalResult -UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, - unsigned memorySpace) { +UnrankedMemRefType::verify(function_ref emitError, + Type elementType, unsigned memorySpace) { if (!BaseMemRefType::isValidElementType(elementType)) - return emitError(loc, "invalid memref element type"); + return emitError() << "invalid memref element type"; return success(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -856,12 +856,13 @@ return Base::get(context, width, signedness); } -IntegerType IntegerType::getChecked(Location location, unsigned width, - SignednessSemantics signedness) { - if (auto cached = - getCachedIntegerType(width, signedness, location->getContext())) +IntegerType +IntegerType::getChecked(function_ref emitError, + MLIRContext *context, unsigned width, + SignednessSemantics signedness) { + if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; - return Base::getChecked(location, width, signedness); + return Base::getChecked(emitError, context, width, signedness); } /// Get an instance of the NoneType. 
@@ -1005,11 +1006,14 @@ // StorageUniquerSupport //===----------------------------------------------------------------------===// -/// Utility method to generate a default location for use when checking the -/// construction invariants of a storage object. This is defined out-of-line to -/// avoid the need to include Location.h. -const AttributeStorage * -mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) { - return reinterpret_cast( - ctx->getImpl().unknownLocAttr.getAsOpaquePointer()); +/// Utility method to generate a callback that can be used to generate a +/// diagnostic when checking the construction invariants of a storage object. +/// This is defined out-of-line to avoid the need to include Location.h. +llvm::unique_function +mlir::detail::getDefaultDiagnosticEmitFn(MLIRContext *ctx) { + return [ctx] { return emitError(UnknownLoc::get(ctx)); }; +} +llvm::unique_function +mlir::detail::getDefaultDiagnosticEmitFn(const Location &loc) { + return [=] { return emitError(loc); }; } diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -524,9 +524,9 @@ // Otherwise, form a new opaque attribute. return OpaqueAttr::getChecked( + getEncodedSourceLocation(loc), Identifier::get(dialectName, state.context), symbolData, - attrType ? attrType : NoneType::get(state.context), - getEncodedSourceLocation(loc)); + attrType ? attrType : NoneType::get(state.context)); }); // Ensure that the attribute has the same type as requested. @@ -563,7 +563,7 @@ // Otherwise, form a new opaque type. 
return OpaqueType::getChecked( - getEncodedSourceLocation(loc), + getEncodedSourceLocation(loc), state.context, Identifier::get(dialectName, state.context), symbolData); }); } diff --git a/mlir/lib/TableGen/TypeDef.cpp b/mlir/lib/TableGen/TypeDef.cpp --- a/mlir/lib/TableGen/TypeDef.cpp +++ b/mlir/lib/TableGen/TypeDef.cpp @@ -23,13 +23,6 @@ // TypeBuilder //===----------------------------------------------------------------------===// -/// Return an optional code body used for the `getChecked` variant of this -/// builder. -Optional TypeBuilder::getCheckedBody() const { - Optional body = def->getValueAsOptionalString("checkedBody"); - return body && !body->empty() ? body : llvm::None; -} - /// Returns true if this builder is able to infer the MLIRContext parameter. bool TypeBuilder::hasInferredContextParameter() const { return def->getValueAsBit("hasInferredContextParam"); @@ -111,8 +104,8 @@ bool TypeDef::genAccessors() const { return def->getValueAsBit("genAccessors"); } -bool TypeDef::genVerifyInvariantsDecl() const { - return def->getValueAsBit("genVerifyInvariantsDecl"); +bool TypeDef::genVerifyDecl() const { + return def->getValueAsBit("genVerifyDecl"); } llvm::Optional TypeDef::getExtraDecls() const { auto value = def->getValueAsString("extraClassDeclaration"); diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -48,7 +48,7 @@ // An example of how one could implement a standard integer.
def IntegerType : Test_Type<"TestInteger"> { let mnemonic = "int"; - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; let parameters = ( ins "unsigned":$width, @@ -67,9 +67,7 @@ let builders = [ TypeBuilder<(ins "unsigned":$width, CArg<"SignednessSemantics", "Signless">:$signedness), [{ - return Base::get($_ctxt, width, signedness); - }], [{ - return Base::getChecked($_loc, width, signedness); + return $_get($_ctxt, width, signedness); }]> ]; let skipDefaultBuilders = 1; @@ -84,7 +82,7 @@ if ($_parser.parseInteger(width)) return Type(); if ($_parser.parseGreater()) return Type(); Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc()); - return getChecked(loc, width, signedness); + return getChecked(loc, loc.getContext(), width, signedness); }]; // Any extra code one wants in the type's class declaration. diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -112,8 +112,10 @@ } // Example type validity checker.
-LogicalResult TestIntegerType::verifyConstructionInvariants( - Location loc, unsigned width, TestIntegerType::SignednessSemantics ss) { +LogicalResult +TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError, + unsigned width, + TestIntegerType::SignednessSemantics ss) { if (width > 8) return failure(); return success(); diff --git a/mlir/test/mlir-tblgen/typedefs.td b/mlir/test/mlir-tblgen/typedefs.td --- a/mlir/test/mlir-tblgen/typedefs.td +++ b/mlir/test/mlir-tblgen/typedefs.td @@ -54,11 +54,11 @@ RTLValueType:$inner ); - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; // DECL-LABEL: class CompoundAType : public ::mlir::Type -// DECL: static CompoundAType getChecked(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); -// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static CompoundAType getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); +// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef dims, ::mlir::Type inner); // DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; } // DECL: static ::mlir::Type parse(::mlir::MLIRContext *context, // DECL-NEXT: ::mlir::DialectAsmParser &parser); @@ -95,7 +95,7 @@ def E_IntegerType : TestType<"Integer"> { let mnemonic = "int"; - let genVerifyInvariantsDecl = 1; + let genVerifyDecl = 1; let parameters = ( ins "SignednessSemantics":$signedness, diff --git
a/mlir/tools/mlir-tblgen/TypeDefGen.cpp b/mlir/tools/mlir-tblgen/TypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/TypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/TypeDefGen.cpp @@ -182,27 +182,29 @@ void print(::mlir::DialectAsmPrinter &printer) const; )"; -/// The code block for the verifyConstructionInvariants and getChecked. +/// The code block for the verify method declaration. /// -/// {0}: The name of the typeDef class. -/// {1}: List of parameters, parameters style. +/// {0}: List of parameters as `type name` pairs, each preceded by a comma. static const char *const typeDefDeclVerifyStr = R"( - static ::mlir::LogicalResult verifyConstructionInvariants(::mlir::Location loc{1}); + using Base::getChecked; + static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError{0}); )"; /// Emit the builders for the given type. static void emitTypeBuilderDecls(const TypeDef &typeDef, raw_ostream &os, TypeParamCommaFormatter &paramTypes) { StringRef typeClass = typeDef.getCppClassName(); - bool genCheckedMethods = typeDef.genVerifyInvariantsDecl(); + bool genCheckedMethods = typeDef.genVerifyDecl(); if (!typeDef.skipDefaultBuilders()) { os << llvm::formatv( " static {0} get(::mlir::MLIRContext *context{1});\n", typeClass, paramTypes); if (genCheckedMethods) { - os << llvm::formatv( - " static {0} getChecked(::mlir::Location loc{1});\n", typeClass, - paramTypes); + os << llvm::formatv(" static {0} " + "getChecked(llvm::function_ref<::mlir::" + "InFlightDiagnostic()> emitError, " + "::mlir::MLIRContext *context{1});\n", + typeClass, paramTypes); } } @@ -231,10 +233,14 @@ // Generate the `getChecked` variant of the builder.
if (genCheckedMethods) { - os << " static " << typeClass << " getChecked(::mlir::Location loc"; + os << " static " << typeClass + << " getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> " + "emitError"; + if (!builder.hasInferredContextParameter()) + os << ", ::mlir::MLIRContext *context"; if (!paramStr.empty()) - os << ", " << paramStr; - os << ");\n"; + os << ", "; + os << paramStr << ");\n"; } } } @@ -265,9 +271,8 @@ emitTypeBuilderDecls(typeDef, os, emitTypeNamePairsAfterComma); // Emit the verify invariants declaration. - if (typeDef.genVerifyInvariantsDecl()) - os << llvm::formatv(typeDefDeclVerifyStr, typeDef.getCppClassName(), - emitTypeNamePairsAfterComma); + if (typeDef.genVerifyDecl()) + os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma); } // Emit the mnenomic, if specified. @@ -515,10 +520,18 @@ } } +/// Replace every 'from' in `str` with 'to', without rescanning inserted text. +static std::string replaceInStr(std::string str, StringRef from, StringRef to) { + // Resume after each insertion; an empty 'from' is a no-op rather than a hang. + for (size_t pos = 0; !from.empty() && (pos = str.find(from.data(), pos, from.size())) != std::string::npos; pos += to.size()) + str.replace(pos, from.size(), to.data(), to.size()); + return str; +} + /// Emit the builders for the given type.
static void emitTypeBuilderDefs(const TypeDef &typeDef, raw_ostream &os, ArrayRef typeDefParams) { - bool genCheckedMethods = typeDef.genVerifyInvariantsDecl(); + bool genCheckedMethods = typeDef.genVerifyDecl(); StringRef typeClass = typeDef.getCppClassName(); if (!typeDef.skipDefaultBuilders()) { os << llvm::formatv( @@ -531,8 +544,10 @@ typeDefParams)); if (genCheckedMethods) { os << llvm::formatv( - "{0} {0}::getChecked(::mlir::Location loc{1}) {{\n" - " return Base::getChecked(loc{2});\n}\n", + "{0} {0}::getChecked(" + "llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, " + "::mlir::MLIRContext *context{1}) {{\n" + " return Base::getChecked(emitError, context{2});\n}\n", typeClass, TypeParamCommaFormatter( TypeParamCommaFormatter::EmitFormat::TypeNamePairs, @@ -542,16 +557,15 @@ } } + auto builderFmtCtx = + FmtContext().addSubst("_ctxt", "context").addSubst("_get", "Base::get"); + auto inferredCtxBuilderFmtCtx = FmtContext().addSubst("_get", "Base::get"); + auto checkedBuilderFmtCtx = FmtContext().addSubst("_ctxt", "context"); + // Generate the builders specified by the user. - auto builderFmtCtx = FmtContext().addSubst("_ctxt", "context"); - auto checkedBuilderFmtCtx = FmtContext() - .addSubst("_loc", "loc") - .addSubst("_ctxt", "loc.getContext()"); for (const TypeBuilder &builder : typeDef.getBuilders()) { Optional body = builder.getBody(); - Optional checkedBody = - genCheckedMethods ? builder.getCheckedBody() : llvm::None; - if (!body && !checkedBody) + if (!body) continue; std::string paramStr; llvm::raw_string_ostream paramOS(paramStr); @@ -565,27 +579,33 @@ paramOS.flush(); // Emit the `get` variant of the builder.
- if (body) { - os << llvm::formatv("{0} {0}::get(", typeClass); - if (!builder.hasInferredContextParameter()) { - os << "::mlir::MLIRContext *context"; - if (!paramStr.empty()) - os << ", "; - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, - tgfmt(*body, &builderFmtCtx).str()); - } else { - os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, *body); - } + os << llvm::formatv("{0} {0}::get(", typeClass); + if (!builder.hasInferredContextParameter()) { + os << "::mlir::MLIRContext *context"; + if (!paramStr.empty()) + os << ", "; + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, + tgfmt(*body, &builderFmtCtx).str()); + } else { + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, + tgfmt(*body, &inferredCtxBuilderFmtCtx).str()); } // Emit the `getChecked` variant of the builder. - if (checkedBody) { - os << llvm::formatv("{0} {0}::getChecked(::mlir::Location loc", + if (genCheckedMethods) { + os << llvm::formatv("{0} " + "{0}::getChecked(llvm::function_ref<::mlir::" + "InFlightDiagnostic()> emitErrorFn", typeClass); + std::string checkedBody = + replaceInStr(body->str(), "$_get(", "Base::getChecked(emitErrorFn, "); + if (!builder.hasInferredContextParameter()) { + os << ", ::mlir::MLIRContext *context"; + checkedBody = tgfmt(checkedBody, &checkedBuilderFmtCtx).str(); + } if (!paramStr.empty()) - os << ", " << paramStr; - os << llvm::formatv(") {{\n {0};\n}\n", - tgfmt(*checkedBody, &checkedBuilderFmtCtx)); + os << ", "; + os << llvm::formatv("{0}) {{\n {1};\n}\n", paramStr, checkedBody); } } }