diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" @@ -19,8 +20,14 @@ // Printing. //===----------------------------------------------------------------------===// -static void printTypeImpl(llvm::raw_ostream &os, Type type, - llvm::SetVector &stack); +/// If the given type is compatible with the LLVM dialect, prints it using +/// internal functions to avoid getting a verbose `!llvm` prefix. Otherwise +/// prints it through the generic `DialectAsmPrinter::printType`. +static void dispatchPrint(DialectAsmPrinter &printer, Type type) { + if (isCompatibleType(type)) + return mlir::LLVM::detail::printType(type, printer); + printer.printType(type); +} /// Returns the keyword to use for the given type. static StringRef getTypeKeyword(Type type) { @@ -48,76 +55,79 @@ }); } -/// Prints the body of a structure type. Uses `stack` to avoid printing -/// recursive structs indefinitely. -static void printStructTypeBody(llvm::raw_ostream &os, LLVMStructType type, - llvm::SetVector &stack) { - if (type.isIdentified() && type.isOpaque()) { - os << "opaque"; - return; - } - - if (type.isPacked()) - os << "packed "; - - // Put the current type on stack to avoid infinite recursion. - os << '('; - if (type.isIdentified()) - stack.insert(type.getName()); - llvm::interleaveComma(type.getBody(), os, [&](Type subtype) { - printTypeImpl(os, subtype, stack); +/// Prints a structure type. Keeps track of known struct names to handle self- +/// or mutually-referring structs without falling into infinite recursion.
+static void printStructType(DialectAsmPrinter &printer, LLVMStructType type) { + // This keeps track of the names of identified structure types that are + // currently being printed. Since such types can refer to themselves, this + // tracking is necessary to stop the recursion: the current function may be + // called recursively from DialectAsmPrinter::printType after the appropriate + // dispatch. We maintain the invariant of this storage being modified + // exclusively in this function, and at most one name being added per call. + // TODO: consider having such functionality inside DialectAsmPrinter. + thread_local llvm::SetVector knownStructNames; + unsigned stackSize = knownStructNames.size(); + (void)stackSize; + auto guard = llvm::make_scope_exit([&]() { + assert(knownStructNames.size() == stackSize && + "malformed identified stack when printing recursive structs"); }); - if (type.isIdentified()) - stack.pop_back(); - os << ')'; -} -/// Prints a structure type. Uses `stack` to keep track of the identifiers of -/// the structs being printed. Checks if the identifier of a struct is contained -/// in `stack`, i.e. whether a self-reference to a recursive stack is being -/// printed, and only prints the name to avoid infinite recursion. -static void printStructType(llvm::raw_ostream &os, LLVMStructType type, - llvm::SetVector &stack) { - os << "<"; + printer << "<"; if (type.isIdentified()) { - os << '"' << type.getName() << '"'; + printer << '"' << type.getName() << '"'; // If we are printing a reference to one of the enclosing structs, just // print the name and stop to avoid infinitely long output.
- if (stack.count(type.getName())) { - os << '>'; + if (knownStructNames.count(type.getName())) { + printer << '>'; return; } - os << ", "; + printer << ", "; + } + + if (type.isIdentified() && type.isOpaque()) { + printer << "opaque>"; + return; } - printStructTypeBody(os, type, stack); - os << '>'; + if (type.isPacked()) + printer << "packed "; + + // Put the current struct's name on the stack to avoid infinite recursion. + printer << '('; + if (type.isIdentified()) + knownStructNames.insert(type.getName()); + llvm::interleaveComma(type.getBody(), printer.getStream(), + [&](Type subtype) { dispatchPrint(printer, subtype); }); + if (type.isIdentified()) + knownStructNames.pop_back(); + printer << ')'; + printer << '>'; } /// Prints a type containing a fixed number of elements. template -static void printArrayOrVectorType(llvm::raw_ostream &os, TypeTy type, - llvm::SetVector &stack) { - os << '<' << type.getNumElements() << " x "; - printTypeImpl(os, type.getElementType(), stack); - os << '>'; +static void printArrayOrVectorType(DialectAsmPrinter &printer, TypeTy type) { + printer << '<' << type.getNumElements() << " x "; + dispatchPrint(printer, type.getElementType()); + printer << '>'; } /// Prints a function type.
-static void printFunctionType(llvm::raw_ostream &os, LLVMFunctionType funcType, - llvm::SetVector &stack) { - os << '<'; - printTypeImpl(os, funcType.getReturnType(), stack); - os << " ("; - llvm::interleaveComma(funcType.getParams(), os, [&os, &stack](Type subtype) { - printTypeImpl(os, subtype, stack); - }); +static void printFunctionType(DialectAsmPrinter &printer, + LLVMFunctionType funcType) { + printer << '<'; + dispatchPrint(printer, funcType.getReturnType()); + printer << " ("; + llvm::interleaveComma( + funcType.getParams(), printer.getStream(), + [&printer](Type subtype) { dispatchPrint(printer, subtype); }); if (funcType.isVarArg()) { if (funcType.getNumParams() != 0) - os << ", "; - os << "..."; + printer << ", "; + printer << "..."; } - os << ")>"; + printer << ")>"; } /// Prints the given LLVM dialect type recursively. This leverages closedness of @@ -129,75 +139,59 @@ /// struct<"c", (ptr>)>>, /// ptr>)>>)> /// note that "b" is printed twice. -static void printTypeImpl(llvm::raw_ostream &os, Type type, - llvm::SetVector &stack) { +void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) { if (!type) { - os << "<>"; + printer << "<>"; return; } - os << getTypeKeyword(type); + printer << getTypeKeyword(type); if (auto intType = type.dyn_cast()) { - os << intType.getBitWidth(); + printer << intType.getBitWidth(); return; } if (auto ptrType = type.dyn_cast()) { - os << '<'; - printTypeImpl(os, ptrType.getElementType(), stack); + printer << '<'; + dispatchPrint(printer, ptrType.getElementType()); if (ptrType.getAddressSpace() != 0) - os << ", " << ptrType.getAddressSpace(); - os << '>'; + printer << ", " << ptrType.getAddressSpace(); + printer << '>'; return; } if (auto arrayType = type.dyn_cast()) - return printArrayOrVectorType(os, arrayType, stack); + return printArrayOrVectorType(printer, arrayType); if (auto vectorType = type.dyn_cast()) - return printArrayOrVectorType(os, vectorType, stack); + return printArrayOrVectorType(printer,
vectorType); if (auto vectorType = type.dyn_cast()) { - os << "'; + printer << "'; return; } if (auto structType = type.dyn_cast()) - return printStructType(os, structType, stack); + return printStructType(printer, structType); if (auto funcType = type.dyn_cast()) - return printFunctionType(os, funcType, stack); -} - -void mlir::LLVM::detail::printType(Type type, DialectAsmPrinter &printer) { - llvm::SetVector stack; - return printTypeImpl(printer.getStream(), type, stack); + return printFunctionType(printer, funcType); } //===----------------------------------------------------------------------===// // Parsing. //===----------------------------------------------------------------------===// -static Type parseTypeImpl(DialectAsmParser &parser, - llvm::SetVector &stack); - -/// Helper to be chained with other parsing functions. -static ParseResult parseTypeImpl(DialectAsmParser &parser, - llvm::SetVector &stack, - Type &result) { - result = parseTypeImpl(parser, stack); - return success(result != nullptr); -} +static ParseResult dispatchParse(DialectAsmParser &parser, Type &type); /// Parses an LLVM dialect function type. /// llvm-type :: = `func<` llvm-type `(` llvm-type-list `...`?
`)>` -static LLVMFunctionType parseFunctionType(DialectAsmParser &parser, - llvm::SetVector &stack) { +static LLVMFunctionType parseFunctionType(DialectAsmParser &parser) { Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); Type returnType; - if (parser.parseLess() || parseTypeImpl(parser, stack, returnType) || + if (parser.parseLess() || dispatchParse(parser, returnType) || parser.parseLParen()) return LLVMFunctionType(); @@ -219,9 +213,10 @@ /*isVarArg=*/true); } - argTypes.push_back(parseTypeImpl(parser, stack)); - if (!argTypes.back()) + Type arg; + if (dispatchParse(parser, arg)) return LLVMFunctionType(); + argTypes.push_back(arg); } while (succeeded(parser.parseOptionalComma())); if (parser.parseOptionalRParen() || parser.parseOptionalGreater()) @@ -232,11 +227,10 @@ /// Parses an LLVM dialect pointer type. /// llvm-type ::= `ptr<` llvm-type (`,` integer)? `>` -static LLVMPointerType parsePointerType(DialectAsmParser &parser, - llvm::SetVector &stack) { +static LLVMPointerType parsePointerType(DialectAsmParser &parser) { Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); Type elementType; - if (parser.parseLess() || parseTypeImpl(parser, stack, elementType)) + if (parser.parseLess() || dispatchParse(parser, elementType)) return LLVMPointerType(); unsigned addressSpace = 0; @@ -251,15 +245,14 @@ /// Parses an LLVM dialect vector type. /// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` /// Supports both fixed and scalable vectors.
-static LLVMVectorType parseVectorType(DialectAsmParser &parser, - llvm::SetVector &stack) { +static LLVMVectorType parseVectorType(DialectAsmParser &parser) { SmallVector dims; llvm::SMLoc dimPos; Type elementType; Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || parser.parseDimensionList(dims, /*allowDynamic=*/true) || - parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) + dispatchParse(parser, elementType) || parser.parseGreater()) return LLVMVectorType(); // We parsed a generic dimension list, but vectors only support two forms: @@ -282,15 +275,14 @@ /// Parses an LLVM dialect array type. /// llvm-type ::= `array<` integer `x` llvm-type `>` -static LLVMArrayType parseArrayType(DialectAsmParser &parser, - llvm::SetVector &stack) { +static LLVMArrayType parseArrayType(DialectAsmParser &parser) { SmallVector dims; llvm::SMLoc sizePos; Type elementType; Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation()); if (parser.parseLess() || parser.getCurrentLocation(&sizePos) || parser.parseDimensionList(dims, /*allowDynamic=*/false) || - parseTypeImpl(parser, stack, elementType) || parser.parseGreater()) + dispatchParse(parser, elementType) || parser.parseGreater()) return LLVMArrayType(); if (dims.size() != 1) { @@ -302,13 +294,11 @@ } /// Attempts to set the body of an identified structure type. Reports a parsing -/// error at `subtypesLoc` in case of failure, uses `stack` to make sure the -/// types printed in the error message look like they did when parsed. +/// error at `subtypesLoc` in case of failure.
static LLVMStructType trySetStructBody(LLVMStructType type, ArrayRef subtypes, bool isPacked, DialectAsmParser &parser, - llvm::SMLoc subtypesLoc, - llvm::SetVector &stack) { + llvm::SMLoc subtypesLoc) { for (Type t : subtypes) { if (!LLVMStructType::isValidElementType(t)) { parser.emitError(subtypesLoc) @@ -320,12 +310,8 @@ if (succeeded(type.setBody(subtypes, isPacked))) return type; - std::string currentBody; - llvm::raw_string_ostream currentBodyStream(currentBody); - printStructTypeBody(currentBodyStream, type, stack); - auto diag = parser.emitError(subtypesLoc) - << "identified type already used with a different body"; - diag.attachNote() << "existing body: " << currentBodyStream.str(); + parser.emitError(subtypesLoc) + << "identified type already used with a different body"; return LLVMStructType(); } @@ -334,8 +320,22 @@ /// `(` llvm-type-list `)` `>` /// | `struct<` string-literal `>` /// | `struct<` string-literal `, opaque>` -static LLVMStructType parseStructType(DialectAsmParser &parser, - llvm::SetVector &stack) { +static LLVMStructType parseStructType(DialectAsmParser &parser) { + // This keeps track of the names of identified structure types that are + // currently being parsed. Since such types can refer to themselves, this + // tracking is necessary to stop the recursion: the current function may be + // called recursively from DialectAsmParser::parseType after the appropriate + // dispatch. We maintain the invariant of this storage being modified + // exclusively in this function, and at most one name being added per call. + // TODO: consider having such functionality inside DialectAsmParser.
+  thread_local llvm::SetVector<StringRef> knownStructNames;
+  unsigned stackSize = knownStructNames.size();
+  (void)stackSize;
+  auto guard = llvm::make_scope_exit([&]() {
+    assert(knownStructNames.size() == stackSize &&
+           "malformed identified stack when parsing recursive structs");
+  });
+
   Location loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
 
   if (failed(parser.parseLess()))
@@ -347,7 +347,7 @@
   StringRef name;
   bool isIdentified = succeeded(parser.parseOptionalString(&name));
   if (isIdentified) {
-    if (stack.count(name)) {
+    if (knownStructNames.count(name)) {
       if (failed(parser.parseGreater()))
         return LLVMStructType();
       return LLVMStructType::getIdentifiedChecked(loc, name);
@@ -384,7 +384,7 @@
     if (!isIdentified)
       return LLVMStructType::getLiteralChecked(loc, {}, isPacked);
     auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-    return trySetStructBody(type, {}, isPacked, parser, kwLoc, stack);
+    return trySetStructBody(type, {}, isPacked, parser, kwLoc);
   }
 
   // Parse subtypes. For identified structs, put the identifier of the struct on
@@ -393,13 +393,15 @@
   llvm::SMLoc subtypesLoc = parser.getCurrentLocation();
   do {
     if (isIdentified)
-      stack.insert(name);
-    Type type = parseTypeImpl(parser, stack);
-    if (!type)
+      knownStructNames.insert(name);
+    Type type;
+    bool subtypeFailed = failed(dispatchParse(parser, type));
+    // Pop even on failure so that an error does not leave the name stacked.
+    if (isIdentified)
+      knownStructNames.pop_back();
+    if (subtypeFailed)
       return LLVMStructType();
     subtypes.push_back(type);
-    if (isIdentified)
-      stack.pop_back();
   } while (succeeded(parser.parseOptionalComma()));
 
   if (parser.parseRParen() || parser.parseGreater())
@@ -409,30 +411,30 @@
   if (!isIdentified)
     return LLVMStructType::getLiteralChecked(loc, subtypes, isPacked);
   auto type = LLVMStructType::getIdentifiedChecked(loc, name);
-  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc, stack);
+  return trySetStructBody(type, subtypes, isPacked, parser, subtypesLoc);
 }
 
-/// Parses one of the LLVM dialect types.
-static Type parseTypeImpl(DialectAsmParser &parser, - llvm::SetVector &stack) { - // Special case for integers (i[1-9][0-9]*) that are literals rather than - // keywords for the parser, so they are not caught by the main dispatch below. - // Try parsing it a built-in integer type instead. - Type maybeIntegerType; - MLIRContext *ctx = parser.getBuilder().getContext(); +/// Parses a type appearing inside another LLVM dialect-compatible type. This +/// will try to parse any type in full form (including types with the `!llvm` +/// prefix), and if no such type is present, fall back to parsing the short- +/// hand form of the LLVM dialect types without the `!llvm` prefix. +static Type dispatchParse(DialectAsmParser &parser) { + Type type; llvm::SMLoc keyLoc = parser.getCurrentLocation(); Location loc = parser.getEncodedSourceLoc(keyLoc); - OptionalParseResult result = parser.parseOptionalType(maybeIntegerType); - if (result.hasValue()) { - if (failed(*result)) + OptionalParseResult parseResult = parser.parseOptionalType(type); + if (parseResult.hasValue()) { + if (failed(*parseResult)) return Type(); - if (!maybeIntegerType.isSignlessInteger()) { - parser.emitError(keyLoc) << "unexpected type, expected i* or keyword"; - return Type(); - } - return LLVMIntegerType::getChecked( - loc, maybeIntegerType.getIntOrFloatBitWidth()); + // Special case for integers (i[1-9][0-9]*) that are literals rather than + // keywords for the parser, so they are not caught by the main dispatch + // below. Convert signless built-in integers into LLVM dialect integers. + auto intType = type.dyn_cast(); + if (!intType || !intType.isSignless()) + return type; + + return LLVMIntegerType::getChecked(loc, intType.getWidth()); } // Dispatch to concrete functions.
@@ -440,6 +440,7 @@ if (failed(parser.parseKeyword(&key))) return Type(); + MLIRContext *ctx = parser.getBuilder().getContext(); return StringSwitch>(key) .Case("void", [&] { return LLVMVoidType::get(ctx); }) .Case("half", [&] { return LLVMHalfType::get(ctx); }) @@ -453,18 +454,32 @@ .Case("token", [&] { return LLVMTokenType::get(ctx); }) .Case("label", [&] { return LLVMLabelType::get(ctx); }) .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) - .Case("func", [&] { return parseFunctionType(parser, stack); }) - .Case("ptr", [&] { return parsePointerType(parser, stack); }) - .Case("vec", [&] { return parseVectorType(parser, stack); }) - .Case("array", [&] { return parseArrayType(parser, stack); }) - .Case("struct", [&] { return parseStructType(parser, stack); }) + .Case("func", [&] { return parseFunctionType(parser); }) + .Case("ptr", [&] { return parsePointerType(parser); }) + .Case("vec", [&] { return parseVectorType(parser); }) + .Case("array", [&] { return parseArrayType(parser); }) + .Case("struct", [&] { return parseStructType(parser); }) .Default([&] { parser.emitError(keyLoc) << "unknown LLVM type: " << key; return Type(); })(); } +/// Parses a type into `type`, returning failure if no type was produced. +static ParseResult dispatchParse(DialectAsmParser &parser, Type &type) { + type = dispatchParse(parser); + return success(type != nullptr); +} + +/// Parses one of the LLVM dialect types.
Type mlir::LLVM::detail::parseType(DialectAsmParser &parser) { - llvm::SetVector stack; - return parseTypeImpl(parser, stack); + llvm::SMLoc loc = parser.getCurrentLocation(); + Type type = dispatchParse(parser); + if (!type) + return type; + if (!isCompatibleType(type)) { + parser.emitError(loc) << "unexpected type, expected i* or keyword"; + return nullptr; + } + return type; } diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir --- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -30,8 +30,7 @@ func @repeated_struct_name() { "some.op"() : () -> !llvm.struct<"a", (ptr>)> - // expected-error @+2 {{identified type already used with a different body}} - // expected-note @+1 {{existing body: (ptr>)}} + // expected-error @+1 {{identified type already used with a different body}} "some.op"() : () -> !llvm.struct<"a", (i32)> } @@ -39,8 +38,7 @@ func @repeated_struct_name_packed() { "some.op"() : () -> !llvm.struct<"a", packed (i32)> - // expected-error @+2 {{identified type already used with a different body}} - // expected-note @+1 {{existing body: packed (i32)}} + // expected-error @+1 {{identified type already used with a different body}} "some.op"() : () -> !llvm.struct<"a", (i32)> } @@ -48,8 +46,7 @@ func @repeated_struct_opaque() { "some.op"() : () -> !llvm.struct<"a", opaque> - // expected-error @+2 {{identified type already used with a different body}} - // expected-note @+1 {{existing body: opaque}} + // expected-error @+1 {{identified type already used with a different body}} "some.op"() : () -> !llvm.struct<"a", ()> } @@ -57,8 +54,7 @@ func @repeated_struct_opaque_non_empty() { "some.op"() : () -> !llvm.struct<"a", opaque> - // expected-error @+2 {{identified type already used with a different body}} - // expected-note @+1 {{existing body: opaque}} + // expected-error @+1 {{identified type already used with a different body}} "some.op"() : () ->
!llvm.struct<"a", (i32, i32)> } @@ -95,8 +91,7 @@ func @explicitly_opaque_struct() { "some.op"() : () -> !llvm.struct<"a", opaque> - // expected-error @+2 {{identified type already used with a different body}} - // expected-note @+1 {{existing body: opaque}} + // expected-error @+1 {{identified type already used with a different body}} "some.op"() : () -> !llvm.struct<"a", ()> } diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -182,3 +182,29 @@ return } +func @verbose() { + // CHECK: !llvm.struct<(i64, struct<(float)>)> + "some.op"() : () -> !llvm.struct<(!llvm.i64, !llvm.struct<(!llvm.float)>)> + return +} + +// ----- + +// Check that type aliases can be used inside LLVM dialect types. Note that +// currently they are _not_ printed back as this would require +// DialectAsmPrinter to have a mechanism for querying the presence and +// usability of an alias outside of its `printType` method. + +!baz = type !llvm.i64 +!qux = type !llvm.struct<(!baz)> + +!rec = type !llvm.struct<"a", (ptr>)> + +// CHECK: aliases +llvm.func @aliases() { + // CHECK: !llvm.struct<(i32, float, struct<(i64)>)> + "some.op"() : () -> !llvm.struct<(i32, float, !qux)> + // CHECK: !llvm.struct<"a", (ptr>)> + "some.op"() : () -> !rec + llvm.return +}