diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h --- a/mlir/include/mlir/IR/DialectImplementation.h +++ b/mlir/include/mlir/IR/DialectImplementation.h @@ -121,6 +121,10 @@ virtual llvm::SMLoc getNameLoc() const = 0; /// Re-encode the given source location as an MLIR location and return it. + /// Note: This method should only be used when a `Location` is necessary, as + /// the encoding process is not efficient. In other cases a more suitable + /// alternative should be used, such as the `getChecked` methods defined + /// below. virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0; /// Returns the full specification of the symbol being parsed. This allows for @@ -163,6 +167,22 @@ return success(); } + /// Invoke the `getChecked` method of the given Attribute or Type class, using + /// the provided location to emit errors in the case of failure. Note that + /// unlike `OpBuilder::getType`, this method does not implicitly insert a + /// context parameter. + template <typename T, typename... ParamsT> + T getChecked(llvm::SMLoc loc, ParamsT &&...params) { + return T::getChecked([&] { return emitError(loc); }, + std::forward<ParamsT>(params)...); + } + /// A variant of `getChecked` that uses the result of `getNameLoc` to emit + /// errors. + template <typename T, typename... ParamsT> T getChecked(ParamsT &&...params) { + return T::getChecked([&] { return emitError(getNameLoc()); }, + std::forward<ParamsT>(params)...); + } + //===--------------------------------------------------------------------===// // Token Parsing //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -178,7 +178,7 @@ /// Parses an LLVM dialect function type. /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? 
`)>` static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) { - Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + llvm::SMLoc loc = parser.getCurrentLocation(); Type returnType; if (parser.parseLess() || dispatchParse(parser, returnType) || parser.parseLParen()) @@ -187,8 +187,8 @@ // Function type without arguments. if (succeeded(parser.parseOptionalRParen())) { if (succeeded(parser.parseGreater())) - return LLVMFunctionType::getChecked(loc, returnType, llvm::None, - /*isVarArg=*/false); + return parser.getChecked<LLVMFunctionType>(loc, returnType, llvm::None, + /*isVarArg=*/false); return LLVMFunctionType(); } @@ -198,8 +198,8 @@ if (succeeded(parser.parseOptionalEllipsis())) { if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) return LLVMFunctionType(); - return LLVMFunctionType::getChecked(loc, returnType, argTypes, - /*isVarArg=*/true); + return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, + /*isVarArg=*/true); } Type arg; @@ -210,14 +210,14 @@ if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) return LLVMFunctionType(); - return LLVMFunctionType::getChecked(loc, returnType, argTypes, - /*isVarArg=*/false); + return parser.getChecked<LLVMFunctionType>(loc, returnType, argTypes, + /*isVarArg=*/false); } /// Parses an LLVM dialect pointer type. /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` static LLVMPointerType parsePointerType(DialectAsmParser &parser) { - Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + llvm::SMLoc loc = parser.getCurrentLocation(); Type elementType; if (parser.parseLess() || dispatchParse(parser, elementType)) return LLVMPointerType(); @@ -228,7 +228,7 @@ return LLVMPointerType(); if (failed(parser.parseGreater())) return LLVMPointerType(); - return LLVMPointerType::getChecked(loc, elementType, addressSpace); + return parser.getChecked<LLVMPointerType>(loc, elementType, addressSpace); } /// Parses an LLVM dialect vector type. 
@@ -238,7 +238,7 @@ SmallVector<int64_t, 2> dims; llvm::SMLoc dimPos, typePos; Type elementType; - Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + llvm::SMLoc loc = parser.getCurrentLocation(); if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || parser.parseDimensionList(dims, /*allowDynamic=*/true) || parser.getCurrentLocation(&typePos) || @@ -259,13 +259,13 @@ bool isScalable = dims.size() == 2; if (isScalable) - return LLVMScalableVectorType::getChecked(loc, elementType, dims[1]); + return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); if (elementType.isSignlessIntOrFloat()) { parser.emitError(typePos) << "cannot use !llvm.vec for built-in primitives, use 'vector' instead"; return Type(); } - return LLVMFixedVectorType::getChecked(loc, elementType, dims[0]); + return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]); } /// Parses an LLVM dialect array type. @@ -274,7 +274,7 @@ SmallVector<int64_t, 1> dims; llvm::SMLoc sizePos; Type elementType; - Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); + llvm::SMLoc loc = parser.getCurrentLocation(); if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || parser.parseDimensionList(dims, /*allowDynamic=*/false) || dispatchParse(parser, elementType) || parser.parseGreater()) @@ -285,7 +285,7 @@ return LLVMArrayType(); } - return LLVMArrayType::getChecked(loc, elementType, dims[0]); + return parser.getChecked<LLVMArrayType>(loc, elementType, dims[0]); } /// Attempts to set the body of an identified structure type. 
Reports a parsing diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -117,7 +117,7 @@ /// storage-range ::= integer-literal `:` integer-literal /// storage-type ::= (`i` | `u`) integer-literal /// expressed-type-spec ::= `:` `f` integer-literal -static Type parseAnyType(DialectAsmParser &parser, Location loc) { +static Type parseAnyType(DialectAsmParser &parser) { IntegerType storageType; FloatType expressedType; unsigned typeFlags = 0; @@ -155,9 +155,8 @@ return nullptr; } - return AnyQuantizedType::getChecked(loc, typeFlags, storageType, - expressedType, storageTypeMin, - storageTypeMax); + return parser.getChecked<AnyQuantizedType>( + typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); } static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, @@ -192,7 +191,7 @@ /// axis-spec ::= `:` integer-literal /// scale-zero ::= float-literal `:` integer-literal /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` -static Type parseUniformType(DialectAsmParser &parser, Location loc) { +static Type parseUniformType(DialectAsmParser &parser) { IntegerType storageType; FloatType expressedType; unsigned typeFlags = 0; @@ -279,14 +278,14 @@ if (isPerAxis) { ArrayRef<double> scalesRef(scales.begin(), scales.end()); ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); - return UniformQuantizedPerAxisType::getChecked( - loc, typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, + return parser.getChecked<UniformQuantizedPerAxisType>( + typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, quantizedDimension, storageTypeMin, storageTypeMax); } - return UniformQuantizedType::getChecked( - loc, typeFlags, storageType, expressedType, scales.front(), - zeroPoints.front(), storageTypeMin, storageTypeMax); + return parser.getChecked<UniformQuantizedType>( + typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), 
storageTypeMin, storageTypeMax); } /// Parses an CalibratedQuantizedType. @@ -295,7 +294,7 @@ /// expressed-spec ::= expressed-type `<` calibrated-range `>` /// expressed-type ::= `f` integer-literal /// calibrated-range ::= float-literal `:` float-literal -static Type parseCalibratedType(DialectAsmParser &parser, Location loc) { +static Type parseCalibratedType(DialectAsmParser &parser) { FloatType expressedType; double min; double max; @@ -314,24 +313,22 @@ return nullptr; } - return CalibratedQuantizedType::getChecked(loc, expressedType, min, max); + return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); } /// Parse a type registered to this dialect. Type QuantizationDialect::parseType(DialectAsmParser &parser) const { - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - // All types start with an identifier that we switch on. StringRef typeNameSpelling; if (failed(parser.parseKeyword(&typeNameSpelling))) return nullptr; if (typeNameSpelling == "uniform") - return parseUniformType(parser, loc); + return parseUniformType(parser); if (typeNameSpelling == "any") - return parseAnyType(parser, loc); + return parseAnyType(parser); if (typeNameSpelling == "calibrated") - return parseCalibratedType(parser, loc); + return parseCalibratedType(parser); parser.emitError(parser.getNameLoc(), "unknown quantized type " + typeNameSpelling); diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp --- a/mlir/lib/Parser/DialectSymbolParser.cpp +++ b/mlir/lib/Parser/DialectSymbolParser.cpp @@ -524,7 +524,7 @@ // Otherwise, form a new opaque attribute. return OpaqueAttr::getChecked( - getEncodedSourceLocation(loc), + [&] { return emitError(loc); }, Identifier::get(dialectName, state.context), symbolData, attrType ? attrType : NoneType::get(state.context)); }); @@ -563,7 +563,7 @@ // Otherwise, form a new opaque type. 
return OpaqueType::getChecked( - getEncodedSourceLocation(loc), + [&] { return emitError(loc); }, Identifier::get(dialectName, state.context), symbolData); }); }