diff --git a/mlir/include/mlir/IR/DialectImplementation.h b/mlir/include/mlir/IR/DialectImplementation.h
--- a/mlir/include/mlir/IR/DialectImplementation.h
+++ b/mlir/include/mlir/IR/DialectImplementation.h
@@ -15,6 +15,7 @@
 #define MLIR_IR_DIALECTIMPLEMENTATION_H
 
 #include "mlir/IR/OpImplementation.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/raw_ostream.h"
@@ -341,6 +342,13 @@
   /// static-dimension-list ::= (integer `x`)*
   virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                          bool allowDynamic = true) = 0;
+
+  /// Returns the names of identified types whose bodies are currently being
+  /// parsed, innermost last. Recursive types (e.g. LLVM identified structs)
+  /// consult it to recognize self-references. The set is shared by every
+  /// type parse of the underlying parser and holds non-owning StringRefs, so
+  /// each pushed name must be popped again, including on failure paths.
+  virtual llvm::SetVector<StringRef> &getStructContext() = 0;
 };
 
 } // end namespace mlir
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp
@@ -180,24 +180,20 @@
 // Parsing.
 //===----------------------------------------------------------------------===//
 
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
-                              llvm::SetVector<StringRef> &stack);
+static LLVMType parseTypeImpl(DialectAsmParser &parser);
 
 /// Helper to be chained with other parsing functions.
-static ParseResult parseTypeImpl(DialectAsmParser &parser,
-                                 llvm::SetVector<StringRef> &stack,
-                                 LLVMType &result) {
-  result = parseTypeImpl(parser, stack);
+static ParseResult parseTypeImpl(DialectAsmParser &parser, LLVMType &result) {
+  result = parseTypeImpl(parser);
   return success(result != nullptr);
 }
 
 /// Parses an LLVM dialect function type.
 ///   llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`? `)>`
-static LLVMFunctionType parseFunctionType(DialectAsmParser &parser,
-                                          llvm::SetVector<StringRef> &stack) {
+static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   LLVMType returnType;
-  if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) ||
+  if (parser.parseLess() || parseTypeImpl(parser, returnType) ||
       parser.parseLParen())
     return LLVMFunctionType();
 
@@ -219,7 +215,7 @@
                                           /*isVarArg=*/true);
     }
 
-    argTypes.push_back(parseTypeImpl(parser, stack));
+    argTypes.push_back(parseTypeImpl(parser));
     if (!argTypes.back())
       return LLVMFunctionType();
   } while (succeeded(parser.parseOptionalComma()));
@@ -232,11 +228,10 @@
 
 /// Parses an LLVM dialect pointer type.
 ///   llvm-type ::= `ptr<` llvm-type (`,` integer)? `>`
-static LLVMPointerType parsePointerType(DialectAsmParser &parser,
-                                        llvm::SetVector<StringRef> &stack) {
+static LLVMPointerType parsePointerType(DialectAsmParser &parser) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   LLVMType elementType;
-  if (parser.parseLess() || parseTypeImpl(parser, stack, elementType))
+  if (parser.parseLess() || parseTypeImpl(parser, elementType))
     return LLVMPointerType();
 
   unsigned addressSpace = 0;
@@ -251,15 +246,14 @@
 /// Parses an LLVM dialect vector type.
 ///   llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>`
 /// Supports both fixed and scalable vectors.
-static LLVMVectorType parseVectorType(DialectAsmParser &parser,
-                                      llvm::SetVector<StringRef> &stack) {
+static LLVMVectorType parseVectorType(DialectAsmParser &parser) {
   SmallVector<int64_t, 2> dims;
   llvm::SMLoc dimPos;
   LLVMType elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&dimPos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/true) ||
-      parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+      parseTypeImpl(parser, elementType) || parser.parseGreater())
     return LLVMVectorType();
 
   // We parsed a generic dimension list, but vectors only support two forms:
@@ -282,15 +276,14 @@
 
 /// Parses an LLVM dialect array type.
 ///   llvm-type ::= `array<` integer `x` llvm-type `>`
-static LLVMArrayType parseArrayType(DialectAsmParser &parser,
-                                    llvm::SetVector<StringRef> &stack) {
+static LLVMArrayType parseArrayType(DialectAsmParser &parser) {
   SmallVector<int64_t, 1> dims;
   llvm::SMLoc sizePos;
   LLVMType elementType;
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
   if (parser.parseLess() || parser.getCurrentLocation(&sizePos) ||
       parser.parseDimensionList(dims, /*allowDynamic=*/false) ||
-      parseTypeImpl(parser, stack, elementType) || parser.parseGreater())
+      parseTypeImpl(parser, elementType) || parser.parseGreater())
     return LLVMArrayType();
 
   if (dims.size() != 1) {
@@ -307,8 +300,7 @@
 static LLVMStructType trySetStructBody(LLVMStructType type,
                                        ArrayRef<LLVMType> subtypes,
                                        bool isPacked, DialectAsmParser &parser,
-                                       llvm::SMLoc subtypesLoc,
-                                       llvm::SetVector<StringRef> &stack) {
+                                       llvm::SMLoc subtypesLoc) {
   for (LLVMType t : subtypes) {
     if (!LLVMStructType::isValidElementType(t)) {
       parser.emitError(subtypesLoc)
@@ -322,6 +314,7 @@
 
   std::string currentBody;
   llvm::raw_string_ostream currentBodyStream(currentBody);
+  llvm::SetVector<StringRef> stack;
   printStructTypeBody(currentBodyStream, type, stack);
   auto diag = parser.emitError(subtypesLoc)
               << "identified type already used with a different body";
@@ -334,8 +327,7 @@
 ///                 `(` llvm-type-list `)` `>`
 ///               | `struct<` string-literal `>`
 ///               | `struct<` string-literal `, opaque>`
-static LLVMStructType parseStructType(DialectAsmParser &parser,
-                                      llvm::SetVector<StringRef> &stack) {
+static LLVMStructType parseStructType(DialectAsmParser &parser) {
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
 
   if (failed(parser.parseLess()))
@@ -347,7 +339,7 @@
   StringRef name;
   bool isIdentified = succeeded(parser.parseOptionalString(&name));
   if (isIdentified) {
-    if (stack.count(name)) {
+    if (parser.getStructContext().count(name)) {
       if (failed(parser.parseGreater()))
         return LLVMStructType();
       return LLVMStructType::getIdentifiedChecked(loc, name);
@@ -384,7 +376,7 @@
     if (!isIdentified)
       return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
     auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-    return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
+    return trySetStructBody(type, {}, isPacked, parser, kwLoc);
   }
 
   // Parse subtypes. For identified structs, put the identifier of the struct on
@@ -393,13 +385,15 @@
   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
   do {
     if (isIdentified)
-      stack.insert(name);
-    LLVMType type = parseTypeImpl(parser, stack);
+      parser.getStructContext().insert(name);
+    LLVMType type = parseTypeImpl(parser);
+    // The context is shared parser state, not a local: pop the name before
+    // the error check so a failed subtype parse does not leave it behind.
+    if (isIdentified)
+      parser.getStructContext().pop_back();
     if (!type)
       return LLVMStructType();
     subtypes.push_back(type);
-    if (isIdentified)
-      stack.pop_back();
   } while (succeeded(parser.parseOptionalComma()));
 
   if (parser.parseRParen() || parser.parseGreater())
@@ -409,12 +403,11 @@
   if (!isIdentified)
     return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
   auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
+  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
 }
 
 /// Parses one of the LLVM dialect types.
-static LLVMType parseTypeImpl(DialectAsmParser &parser,
-                              llvm::SetVector<StringRef> &stack) {
+static LLVMType parseTypeImpl(DialectAsmParser &parser) {
   // Special case for integers (i[1-9][0-9]*) that are literals rather than
   // keywords for the parser, so they are not caught by the main dispatch below.
   // Try parsing it a built-in integer type instead.
@@ -453,11 +446,11 @@
       .Case("token", [&] { return LLVMTokenType::get(ctx); })
       .Case("label", [&] { return LLVMLabelType::get(ctx); })
       .Case("metadata", [&] { return LLVMMetadataType::get(ctx); })
-      .Case("func", [&] { return parseFunctionType(parser, stack); })
-      .Case("ptr", [&] { return parsePointerType(parser, stack); })
-      .Case("vec", [&] { return parseVectorType(parser, stack); })
-      .Case("array", [&] { return parseArrayType(parser, stack); })
-      .Case("struct", [&] { return parseStructType(parser, stack); })
+      .Case("func", [&] { return parseFunctionType(parser); })
+      .Case("ptr", [&] { return parsePointerType(parser); })
+      .Case("vec", [&] { return parseVectorType(parser); })
+      .Case("array", [&] { return parseArrayType(parser); })
+      .Case("struct", [&] { return parseStructType(parser); })
       .Default([&] {
         parser.emitError(keyLoc) << "unknown LLVM type: " << key;
         return LLVMType();
@@ -465,6 +458,5 @@
 }
 
 LLVMType mlir::LLVM::detail::parseType(DialectAsmParser &parser) {
-  llvm::SetVector<StringRef> stack;
-  return parseTypeImpl(parser, stack);
+  return parseTypeImpl(parser);
 }
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -308,6 +308,11 @@
     return parser.parseDimensionListRanked(dimensions, allowDynamic);
   }
 
+  /// Forwards to the struct context held by the underlying parser state.
+  llvm::SetVector<StringRef> &getStructContext() override {
+    return parser.getStructContext();
+  }
+
   OptionalParseResult parseOptionalType(Type &result) override {
     return parser.parseOptionalType(result);
   }
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -177,6 +177,10 @@
   // unknown marker.
   ParseResult parseStrideList(SmallVectorImpl<int64_t> &dimensions);
 
+  /// Returns the identified-struct names currently being parsed; the storage
+  /// lives in the shared ParserState.
+  llvm::SetVector<StringRef> &getStructContext();
+
   //===--------------------------------------------------------------------===//
   // Attribute Parsing
   //===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/ParserState.h b/mlir/lib/Parser/ParserState.h
--- a/mlir/lib/Parser/ParserState.h
+++ b/mlir/lib/Parser/ParserState.h
@@ -11,6 +11,7 @@
 
 #include "Lexer.h"
 #include "mlir/IR/Attributes.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/StringMap.h"
 
 namespace mlir {
@@ -77,6 +78,10 @@
 
   /// The depth of this parser in the nested parsing stack.
   size_t parserDepth;
+
+  /// Names of identified structs whose bodies are currently being parsed,
+  /// innermost last; exposed through Parser::getStructContext.
+  llvm::SetVector<StringRef> structContext;
 };
 
 } // end namespace detail
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -568,3 +568,7 @@
     return failure();
   return success();
 }
+
+llvm::SetVector<StringRef> &Parser::getStructContext() {
+  return state.structContext;
+}