diff --git a/mlir/docs/DefiningAttributesAndTypes.md b/mlir/docs/DefiningAttributesAndTypes.md --- a/mlir/docs/DefiningAttributesAndTypes.md +++ b/mlir/docs/DefiningAttributesAndTypes.md @@ -194,42 +194,34 @@ /// This method is used to get an instance of the 'ComplexType'. This method /// asserts that all of the construction invariants were satisfied. To /// gracefully handle failed construction, getChecked should be used instead. - static ComplexType get(MLIRContext *context, unsigned param, Type type) { + static ComplexType get(unsigned param, Type type) { // Call into a helper 'get' method in 'TypeBase' to get a uniqued instance // of this type. All parameters to the storage class are passed after the // type kind. - return Base::get(context, MyTypes::Complex, param, type); + return Base::get(type.getContext(), MyTypes::Complex, param, type); } /// This method is used to get an instance of the 'ComplexType', defined at /// the given location. If any of the construction invariants are invalid, /// errors are emitted with the provided location and a null type is returned. /// Note: This method is completely optional. - static ComplexType getChecked(MLIRContext *context, unsigned param, Type type, - Location location) { + static ComplexType getChecked(unsigned param, Type type, Location location) { // Call into a helper 'getChecked' method in 'TypeBase' to get a uniqued // instance of this type. All parameters to the storage class are passed // after the type kind. - return Base::getChecked(location, context, MyTypes::Complex, param, type); + return Base::getChecked(location, MyTypes::Complex, param, type); } /// This method is used to verify the construction invariants passed into the /// 'get' and 'getChecked' methods. Note: This method is completely optional. static LogicalResult verifyConstructionInvariants( - llvm::Optional loc, MLIRContext *context, unsigned param, - Type type) { + Location loc, unsigned param, Type type) { // Our type only allows non-zero parameters. - if (param == 0) { - if (loc) - context->emitError(loc) << "non-zero parameter passed to 'ComplexType'"; - return failure(); - } + if (param == 0) + return emitError(loc) << "non-zero parameter passed to 'ComplexType'"; // Our type also expects an integer type. - if (!type.isa()) { - if (loc) - context->emitError(loc) << "non integer-type passed to 'ComplexType'"; - return failure(); - } + if (!type.isa()) + return emitError(loc) << "non integer-type passed to 'ComplexType'"; return success(); } diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h --- a/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h +++ b/mlir/include/mlir/Dialect/QuantOps/QuantTypes.h @@ -66,8 +66,7 @@ static constexpr unsigned MaxStorageBits = 32; static LogicalResult - verifyConstructionInvariants(Optional loc, MLIRContext *context, - unsigned flags, Type storageType, + verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax); @@ -229,8 +228,7 @@ /// Verifies construction invariants and issues errors/warnings. static LogicalResult - verifyConstructionInvariants(Optional loc, MLIRContext *context, - unsigned flags, Type storageType, + verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax); }; @@ -288,10 +286,11 @@ Location location); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax); + static LogicalResult + verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, + Type expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax); /// Support method to enable LLVM-style type casting. static bool kindof(unsigned kind) { @@ -351,11 +350,12 @@ int64_t storageTypeMax, Location location); /// Verifies construction invariants and issues errors/warnings. - static LogicalResult verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, ArrayRef scales, - ArrayRef zeroPoints, int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax); + static LogicalResult + verifyConstructionInvariants(Location loc, unsigned flags, Type storageType, + Type expressedType, ArrayRef scales, + ArrayRef zeroPoints, + int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax); /// Support method to enable LLVM-style type casting. static bool kindof(unsigned kind) { diff --git a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h --- a/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h +++ b/mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h @@ -90,10 +90,11 @@ static bool kindof(unsigned kind) { return kind == AttrKind::TargetEnv; } - static LogicalResult - verifyConstructionInvariants(Optional loc, MLIRContext *context, - IntegerAttr version, ArrayAttr extensions, - ArrayAttr capabilities, DictionaryAttr limits); + static LogicalResult verifyConstructionInvariants(Location loc, + IntegerAttr version, + ArrayAttr extensions, + ArrayAttr capabilities, + DictionaryAttr limits); }; /// Returns the attribute name for specifying argument ABI information. diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -330,11 +330,9 @@ } /// Verify the construction invariants for a double value. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, Type type, + static LogicalResult verifyConstructionInvariants(Location loc, Type type, double value); - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, Type type, + static LogicalResult verifyConstructionInvariants(Location loc, Type type, const APFloat &value); }; @@ -361,11 +359,9 @@ return kind == StandardAttributes::Integer; } - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, Type type, + static LogicalResult verifyConstructionInvariants(Location loc, Type type, int64_t value); - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, Type type, + static LogicalResult verifyConstructionInvariants(Location loc, Type type, const APInt &value); }; @@ -419,8 +415,7 @@ StringRef getAttrData() const; /// Verify the construction of an opaque attribute. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, Identifier dialect, StringRef attrData, Type type); diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -54,6 +54,12 @@ Location(LocationAttr loc) : impl(loc) { assert(loc && "location should never be null."); } + Location(const LocationAttr::ImplType *impl) : impl(impl) { + assert(impl && "location should never be null."); + } + + /// Return the context this location is uniqued in. + MLIRContext *getContext() const { return impl.getContext(); } /// Access the impl location attribute. operator LocationAttr() const { return impl; } diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -96,8 +96,7 @@ Location location); /// Verify the construction of an integer type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, unsigned width); /// Return the bitwidth of this integer type. @@ -162,8 +161,7 @@ static ComplexType getChecked(Type elementType, Location location); /// Verify the construction of an integer type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, Type elementType); Type getElementType(); @@ -270,8 +268,7 @@ Location location); /// Verify the construction of a vector type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, ArrayRef shape, Type elementType); @@ -329,8 +326,7 @@ Location location); /// Verify the construction of a ranked tensor type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, ArrayRef shape, Type elementType); @@ -360,8 +356,7 @@ static UnrankedTensorType getChecked(Type elementType, Location location); /// Verify the construction of a unranked tensor type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, Type elementType); ArrayRef getShape() const { return llvm::None; } @@ -505,8 +500,7 @@ Location location); /// Verify the construction of a unranked memref type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, Type elementType, unsigned memorySpace); diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -18,10 +18,15 @@ #include "mlir/Support/StorageUniquer.h" namespace mlir { -class Location; +class AttributeStorage; class MLIRContext; namespace detail { +/// Utility method to generate a raw default location for use when checking the +/// construction invariants of a storage object. This is defined out-of-line to +/// avoid the need to include Location.h. +const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx); + /// Utility class for implementing users of storage classes uniqued by a /// StorageUniquer. Clients are not expected to interact with this class /// directly. @@ -53,21 +58,20 @@ template static ConcreteT get(MLIRContext *ctx, unsigned kind, Args... args) { // Ensure that the invariants are correct for construction. - assert(succeeded( - ConcreteT::verifyConstructionInvariants(llvm::None, ctx, args...))); + assert(succeeded(ConcreteT::verifyConstructionInvariants( + generateUnknownStorageLocation(ctx), args...))); return UniquerT::template get(ctx, kind, args...); } /// Get or create a new ConcreteT instance within the ctx, defined at /// the given, potentially unknown, location. If the arguments provided are /// invalid then emit errors and return a null object. - template - static ConcreteT getChecked(const Location &loc, MLIRContext *ctx, - unsigned kind, Args... args) { + template + static ConcreteT getChecked(LocationT loc, unsigned kind, Args... args) { // If the construction invariants fail then we return a null attribute. - if (failed(ConcreteT::verifyConstructionInvariants(loc, ctx, args...))) + if (failed(ConcreteT::verifyConstructionInvariants(loc, args...))) return ConcreteT(); - return UniquerT::template get(ctx, kind, args...); + return UniquerT::template get(loc.getContext(), kind, args...); } /// Default implementation that just returns success. diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -46,10 +46,8 @@ /// current type. Used for isa/dyn_cast casting functionality. /// /// * Optional: -/// - static LogicalResult verifyConstructionInvariants( -/// Optional loc, -/// MLIRContext *context, -/// Args... args) +/// - static LogicalResult verifyConstructionInvariants(Location loc, +/// Args... args) /// * This method is invoked when calling the 'TypeBase::get/getChecked' /// methods to ensure that the arguments passed in are valid to construct /// a type instance with. @@ -238,8 +236,7 @@ StringRef getTypeData() const; /// Verify the construction of an opaque type. - static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, + static LogicalResult verifyConstructionInvariants(Location loc, Identifier dialect, StringRef typeData); diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp --- a/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantTypes.cpp @@ -24,20 +24,19 @@ } LogicalResult QuantizedType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax) { + Location loc, unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) { // Verify that the storage type is integral. // This restriction may be lifted at some point in favor of using bf16 // or f16 as exact representations on hardware where that is advantageous. auto intStorageType = storageType.dyn_cast(); if (!intStorageType) - return emitOptionalError(loc, "storage type must be integral"); + return emitError(loc, "storage type must be integral"); unsigned integralWidth = intStorageType.getWidth(); // Verify storage width. if (integralWidth == 0 || integralWidth > MaxStorageBits) - return emitOptionalError(loc, "illegal storage type size: ", integralWidth); + return emitError(loc, "illegal storage type size: ") << integralWidth; // Verify storageTypeMin and storageTypeMax. bool isSigned = @@ -49,8 +48,8 @@ if (storageTypeMax - storageTypeMin <= 0 || storageTypeMin < defaultIntegerMin || storageTypeMax > defaultIntegerMax) { - return emitOptionalError(loc, "illegal storage min and storage max: (", - storageTypeMin, ":", storageTypeMax, ")"); + return emitError(loc, "illegal storage min and storage max: (") + << storageTypeMin << ":" << storageTypeMax << ")"; } return success(); } @@ -209,17 +208,15 @@ int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, storageType.getContext(), - QuantizationTypes::Any, flags, storageType, + return Base::getChecked(location, QuantizationTypes::Any, flags, storageType, expressedType, storageTypeMin, storageTypeMax); } LogicalResult AnyQuantizedType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, int64_t storageTypeMin, - int64_t storageTypeMax) { + Location loc, unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) { if (failed(QuantizedType::verifyConstructionInvariants( - loc, context, flags, storageType, expressedType, storageTypeMin, + loc, flags, storageType, expressedType, storageTypeMin, storageTypeMax))) { return failure(); } @@ -228,7 +225,7 @@ // If this restriction is ever eliminated, the parser/printer must be // extended. if (expressedType && !expressedType.isa()) - return emitOptionalError(loc, "expressed type must be floating point"); + return emitError(loc, "expressed type must be floating point"); return success(); } @@ -249,18 +246,17 @@ Type expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, storageType.getContext(), - QuantizationTypes::UniformQuantized, flags, + return Base::getChecked(location, QuantizationTypes::UniformQuantized, flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax); } LogicalResult UniformQuantizedType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, double scale, int64_t zeroPoint, - int64_t storageTypeMin, int64_t storageTypeMax) { + Location loc, unsigned flags, Type storageType, Type expressedType, + double scale, int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax) { if (failed(QuantizedType::verifyConstructionInvariants( - loc, context, flags, storageType, expressedType, storageTypeMin, + loc, flags, storageType, expressedType, storageTypeMin, storageTypeMax))) { return failure(); } @@ -268,18 +264,17 @@ // Uniform quantization requires fully expressed parameters, including // expressed type. if (!expressedType) - return emitOptionalError(loc, - "uniform quantization requires expressed type"); + return emitError(loc, "uniform quantization requires expressed type"); // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. if (!expressedType.isa()) - return emitOptionalError(loc, "expressed type must be floating point"); + return emitError(loc, "expressed type must be floating point"); // Verify scale. if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) - return emitOptionalError(loc, "illegal scale: ", scale); + return emitError(loc, "illegal scale: ") << scale; return success(); } @@ -306,19 +301,18 @@ ArrayRef scales, ArrayRef zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax, Location location) { - return Base::getChecked(location, storageType.getContext(), - QuantizationTypes::UniformQuantizedPerAxis, flags, - storageType, expressedType, scales, zeroPoints, + return Base::getChecked(location, QuantizationTypes::UniformQuantizedPerAxis, + flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax); } LogicalResult UniformQuantizedPerAxisType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, unsigned flags, - Type storageType, Type expressedType, ArrayRef scales, - ArrayRef zeroPoints, int32_t quantizedDimension, - int64_t storageTypeMin, int64_t storageTypeMax) { + Location loc, unsigned flags, Type storageType, Type expressedType, + ArrayRef scales, ArrayRef zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { if (failed(QuantizedType::verifyConstructionInvariants( - loc, context, flags, storageType, expressedType, storageTypeMin, + loc, flags, storageType, expressedType, storageTypeMin, storageTypeMax))) { return failure(); } @@ -326,24 +320,23 @@ // Uniform quantization requires fully expressed parameters, including // expressed type. if (!expressedType) - return emitOptionalError(loc, - "uniform quantization requires expressed type"); + return emitError(loc, "uniform quantization requires expressed type"); // Verify that the expressed type is floating point. // If this restriction is ever eliminated, the parser/printer must be // extended. if (!expressedType.isa()) - return emitOptionalError(loc, "expressed type must be floating point"); + return emitError(loc, "expressed type must be floating point"); // Ensure that the number of scales and zeroPoints match. if (scales.size() != zeroPoints.size()) - return emitOptionalError(loc, "illegal number of scales and zeroPoints: ", - scales.size(), ", ", zeroPoints.size()); + return emitError(loc, "illegal number of scales and zeroPoints: ") + << scales.size() << ", " << zeroPoints.size(); // Verify scale. for (double scale : scales) { if (scale <= 0.0 || std::isinf(scale) || std::isnan(scale)) - return emitOptionalError(loc, "illegal scale: ", scale); + return emitError(loc, "illegal scale: ") << scale; } return success(); diff --git a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp --- a/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/TargetAndABI.cpp @@ -103,10 +103,10 @@ } LogicalResult spirv::TargetEnvAttr::verifyConstructionInvariants( - Optional loc, MLIRContext *context, IntegerAttr version, - ArrayAttr extensions, ArrayAttr capabilities, DictionaryAttr limits) { + Location loc, IntegerAttr version, ArrayAttr extensions, + ArrayAttr capabilities, DictionaryAttr limits) { if (!version.getType().isInteger(32)) - return emitOptionalError(loc, "expected 32-bit integer for version"); + return emitError(loc, "expected 32-bit integer for version"); if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { if (auto strAttr = attr.dyn_cast()) @@ -114,7 +114,7 @@ return true; return false; })) - return emitOptionalError(loc, "unknown extension in extension list"); + return emitError(loc, "unknown extension in extension list"); if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { if (auto intAttr = attr.dyn_cast()) @@ -122,11 +122,10 @@ return true; return false; })) - return emitOptionalError(loc, "unknown capability in capability list"); + return emitError(loc, "unknown capability in capability list"); if (!limits.isa()) - return emitOptionalError(loc, - "expected spirv::ResourceLimitsAttr for limits"); + return emitError(loc, "expected spirv::ResourceLimitsAttr for limits"); return success(); } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -182,8 +182,7 @@ } FloatAttr FloatAttr::getChecked(Type type, double value, Location loc) { - return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, - type, value); + return Base::getChecked(loc, StandardAttributes::Float, type, value); } FloatAttr FloatAttr::get(Type type, const APFloat &value) { @@ -191,8 +190,7 @@ } FloatAttr FloatAttr::getChecked(Type type, const APFloat &value, Location loc) { - return Base::getChecked(loc, type.getContext(), StandardAttributes::Float, - type, value); + return Base::getChecked(loc, StandardAttributes::Float, type, value); } APFloat FloatAttr::getValue() const { return getImpl()->getValue(); } @@ -210,22 +208,18 @@ } /// Verify construction invariants. -static LogicalResult verifyFloatTypeInvariants(Optional loc, - Type type) { +static LogicalResult verifyFloatTypeInvariants(Location loc, Type type) { if (!type.isa()) - return emitOptionalError(loc, "expected floating point type"); + return emitError(loc, "expected floating point type"); return success(); } -LogicalResult FloatAttr::verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, - Type type, double value) { +LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, + double value) { return verifyFloatTypeInvariants(loc, type); } -LogicalResult FloatAttr::verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, - Type type, +LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, const APFloat &value) { // Verify that the type is correct. if (failed(verifyFloatTypeInvariants(loc, type))) @@ -233,7 +227,7 @@ // Verify that the type semantics match that of the value. if (&type.cast().getFloatSemantics() != &value.getSemantics()) { - return emitOptionalError( + return emitError( loc, "FloatAttr type doesn't match the type implied by its value"); } return success(); @@ -286,31 +280,26 @@ int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } -static LogicalResult verifyIntegerTypeInvariants(Optional loc, - Type type) { +static LogicalResult verifyIntegerTypeInvariants(Location loc, Type type) { if (type.isa() || type.isa()) return success(); - return emitOptionalError(loc, "expected integer or index type"); + return emitError(loc, "expected integer or index type"); } -LogicalResult IntegerAttr::verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, - Type type, +LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, int64_t value) { return verifyIntegerTypeInvariants(loc, type); } -LogicalResult IntegerAttr::verifyConstructionInvariants(Optional loc, - MLIRContext *ctx, - Type type, +LogicalResult IntegerAttr::verifyConstructionInvariants(Location loc, Type type, const APInt &value) { if (failed(verifyIntegerTypeInvariants(loc, type))) return failure(); if (auto integerType = type.dyn_cast()) if (integerType.getWidth() != value.getBitWidth()) - return emitOptionalError( - loc, "integer type bit width (", integerType.getWidth(), - ") doesn't match value bit width (", value.getBitWidth(), ")"); + return emitError(loc, "integer type bit width (") + << integerType.getWidth() << ") doesn't match value bit width (" + << value.getBitWidth() << ")"; return success(); } @@ -337,8 +326,8 @@ OpaqueAttr OpaqueAttr::getChecked(Identifier dialect, StringRef attrData, Type type, Location location) { - return Base::getChecked(location, type.getContext(), - StandardAttributes::Opaque, dialect, attrData, type); + return Base::getChecked(location, StandardAttributes::Opaque, dialect, + attrData, type); } /// Returns the dialect namespace of the opaque attribute. @@ -350,13 +339,12 @@ StringRef OpaqueAttr::getAttrData() const { return getImpl()->attrData; } /// Verify the construction of an opaque attribute. -LogicalResult OpaqueAttr::verifyConstructionInvariants(Optional loc, - MLIRContext *context, +LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, Identifier dialect, StringRef attrData, Type type) { if (!Dialect::isValidNamespace(dialect.strref())) - return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); + return emitError(loc, "invalid dialect namespace '") << dialect << "'"; return success(); } diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -518,7 +518,7 @@ Location location) { if (auto cached = getCachedIntegerType(width, context)) return cached; - return Base::getChecked(location, context, StandardTypes::Integer, width); + return Base::getChecked(location, StandardTypes::Integer, width); } /// Get an instance of the NoneType. @@ -639,3 +639,16 @@ llvm::sys::SmartScopedWriter affineLock(impl.affineMutex); return constructorFn(); } + +//===----------------------------------------------------------------------===// +// StorageUniquerSupport +//===----------------------------------------------------------------------===// + +/// Utility method to generate a default location for use when checking the +/// construction invariants of a storage object. This is defined out-of-line to +/// avoid the need to include Location.h. +const AttributeStorage * +mlir::detail::generateUnknownStorageLocation(MLIRContext *ctx) { + return reinterpret_cast( + ctx->getImpl().unknownLocAttr.getAsOpaquePointer()); +} diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -52,12 +52,11 @@ constexpr unsigned IntegerType::kMaxWidth; /// Verify the construction of an integer type. -LogicalResult IntegerType::verifyConstructionInvariants(Optional loc, - MLIRContext *context, +LogicalResult IntegerType::verifyConstructionInvariants(Location loc, unsigned width) { if (width > IntegerType::kMaxWidth) { - return emitOptionalError(loc, "integer bitwidth is limited to ", - IntegerType::kMaxWidth, " bits"); + return emitError(loc) << "integer bitwidth is limited to " + << IntegerType::kMaxWidth << " bits"; } return success(); } @@ -203,24 +202,20 @@ VectorType VectorType::getChecked(ArrayRef shape, Type elementType, Location location) { - return Base::getChecked(location, elementType.getContext(), - StandardTypes::Vector, shape, elementType); + return Base::getChecked(location, StandardTypes::Vector, shape, elementType); } -LogicalResult VectorType::verifyConstructionInvariants(Optional loc, - MLIRContext *context, +LogicalResult VectorType::verifyConstructionInvariants(Location loc, ArrayRef shape, Type elementType) { if (shape.empty()) - return emitOptionalError(loc, - "vector types must have at least one dimension"); + return emitError(loc, "vector types must have at least one dimension"); if (!isValidElementType(elementType)) - return emitOptionalError(loc, "vector elements must be int or float type"); + return emitError(loc, "vector elements must be int or float type"); if (any_of(shape, [](int64_t i) { return i <= 0; })) - return emitOptionalError(loc, - "vector types must have positive constant sizes"); + return emitError(loc, "vector types must have positive constant sizes"); return success(); } @@ -233,11 +228,10 @@ // Check if "elementType" can be an element type of a tensor. Emit errors if // location is not nullptr. Returns failure if check failed. -static inline LogicalResult checkTensorElementType(Optional location, - MLIRContext *context, +static inline LogicalResult checkTensorElementType(Location location, Type elementType) { if (!TensorType::isValidElementType(elementType)) - return emitOptionalError(location, "invalid tensor element type"); + return emitError(location, "invalid tensor element type"); return success(); } @@ -254,18 +248,17 @@ RankedTensorType RankedTensorType::getChecked(ArrayRef shape, Type elementType, Location location) { - return Base::getChecked(location, elementType.getContext(), - StandardTypes::RankedTensor, shape, elementType); + return Base::getChecked(location, StandardTypes::RankedTensor, shape, + elementType); } LogicalResult RankedTensorType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, ArrayRef shape, - Type elementType) { + Location loc, ArrayRef shape, Type elementType) { for (int64_t s : shape) { if (s < -1) - return emitOptionalError(loc, "invalid tensor dimension size"); + return emitError(loc, "invalid tensor dimension size"); } - return checkTensorElementType(loc, context, elementType); + return checkTensorElementType(loc, elementType); } ArrayRef RankedTensorType::getShape() const { @@ -283,13 +276,13 @@ UnrankedTensorType UnrankedTensorType::getChecked(Type elementType, Location location) { - return Base::getChecked(location, elementType.getContext(), - StandardTypes::UnrankedTensor, elementType); + return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType); } -LogicalResult UnrankedTensorType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, Type elementType) { - return checkTensorElementType(loc, context, elementType); +LogicalResult +UnrankedTensorType::verifyConstructionInvariants(Location loc, + Type elementType) { + return checkTensorElementType(loc, elementType); } //===----------------------------------------------------------------------===// @@ -399,8 +392,7 @@ UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType, unsigned memorySpace, Location location) { - return Base::getChecked(location, elementType.getContext(), - StandardTypes::UnrankedMemRef, elementType, + return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType, memorySpace); } @@ -408,13 +400,13 @@ return getImpl()->memorySpace; } -LogicalResult UnrankedMemRefType::verifyConstructionInvariants( - Optional loc, MLIRContext *context, Type elementType, - unsigned memorySpace) { +LogicalResult +UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, + unsigned memorySpace) { // Check that memref is formed from allowed types. if (!elementType.isIntOrFloat() && !elementType.isa() && !elementType.isa()) - return emitOptionalError(*loc, "invalid memref element type"); + return emitError(loc, "invalid memref element type"); return success(); } @@ -621,16 +613,14 @@ } ComplexType ComplexType::getChecked(Type elementType, Location location) { - return Base::getChecked(location, elementType.getContext(), - StandardTypes::Complex, elementType); + return Base::getChecked(location, StandardTypes::Complex, elementType); } /// Verify the construction of an integer type. -LogicalResult ComplexType::verifyConstructionInvariants(Optional loc, - MLIRContext *context, +LogicalResult ComplexType::verifyConstructionInvariants(Location loc, Type elementType) { if (!elementType.isa() && !elementType.isa()) - return emitOptionalError(loc, "invalid element type for complex"); + return emitError(loc, "invalid element type for complex"); return success(); } diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -59,7 +59,7 @@ OpaqueType OpaqueType::getChecked(Identifier dialect, StringRef typeData, MLIRContext *context, Location location) { - return Base::getChecked(location, context, Kind::Opaque, dialect, typeData); + return Base::getChecked(location, Kind::Opaque, dialect, typeData); } /// Returns the dialect namespace of the opaque type. @@ -71,11 +71,10 @@ StringRef OpaqueType::getTypeData() const { return getImpl()->typeData; } /// Verify the construction of an opaque type. -LogicalResult OpaqueType::verifyConstructionInvariants(Optional loc, - MLIRContext *context, +LogicalResult OpaqueType::verifyConstructionInvariants(Location loc, Identifier dialect, StringRef typeData) { if (!Dialect::isValidNamespace(dialect.strref())) - return emitOptionalError(loc, "invalid dialect namespace '", dialect, "'"); + return emitError(loc, "invalid dialect namespace '") << dialect << "'"; return success(); }