NOTE(review): This patch text appears damaged in extraction. Its line breaks
have been collapsed, and spans that look like HTML tags (`<...>`) have been
stripped, e.g. `dyn_cast(type)` (presumably
`dyn_cast<TestRecursiveAliasType>(type)`) and `!test.test_rec_alias>`. It will
not apply as shown; restore it from the upstream change before use.

Summary of the changes in this patch:
- mlir/lib/IR/AsmPrinter.cpp: AliasInitializer::markAliasNonDeferrable returns
  early when the alias at `aliasIndex` already has `canBeDeferred == false`,
  so the recursive propagation to child aliases stops at aliases that were
  already processed (presumably preventing unbounded recursion when aliases
  refer to each other cyclically).
- mlir/test/IR/recursive-type.mlir: adds round-trip checks that repeated uses
  of recursive `!test.test_rec_alias` types print through the aliases captured
  as NAME and NAME2.
- mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp: the type alias hook
  emits the type's name (`getName()`) as a FinalAlias for the new recursive
  alias type (the `dyn_cast` target was lost in extraction; presumably
  TestRecursiveAliasType).
- mlir/test/lib/Dialect/Test/TestTypeDefs.td: defines TestRecursiveAlias
  (mnemonic `test_rec_alias`), a mutable type with a single StringRef `name`
  parameter that reuses test::TestRecursiveTypeStorage, has a custom assembly
  format, and declares getBody()/setBody(Type).
- mlir/test/lib/Dialect/Test/TestTypes.h: moves the GET_TYPEDEF_CLASSES include
  of TestTypeDefs.h.inc from before `namespace test` to the end of the header,
  presumably so the generated class can see TestRecursiveTypeStorage, which
  the .td names as its storage class.
- mlir/test/lib/Dialect/Test/TestTypes.cpp: implements getBody/setBody/getName
  and the custom parse/print for TestRecursiveAliasType.

diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1056,6 +1056,12 @@ void AliasInitializer::markAliasNonDeferrable(size_t aliasIndex) { auto it = std::next(aliases.begin(), aliasIndex); + + // If already marked non-deferrable stop the recursion. + // All children should already be marked non-deferrable as well. + if (!it->second.canBeDeferred) + return; + it->second.canBeDeferred = false; // Propagate the non-deferrable flag to any child aliases. diff --git a/mlir/test/IR/recursive-type.mlir b/mlir/test/IR/recursive-type.mlir --- a/mlir/test/IR/recursive-type.mlir +++ b/mlir/test/IR/recursive-type.mlir @@ -1,6 +1,8 @@ // RUN: mlir-opt %s -test-recursive-types | FileCheck %s // CHECK: !testrec = !test.test_rec> +// CHECK: ![[$NAME:.*]] = !test.test_rec_alias> +// CHECK: ![[$NAME2:.*]] = !test.test_rec_alias, i32>> // CHECK-LABEL: @roundtrip func.func @roundtrip() { @@ -12,6 +14,16 @@ // into inifinite recursion. 
// CHECK: !testrec "test.dummy_op_for_roundtrip"() : () -> !test.test_rec> + + // CHECK: () -> ![[$NAME]] + // CHECK: () -> ![[$NAME]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias> + + // CHECK: () -> ![[$NAME2]] + // CHECK: () -> ![[$NAME2]] + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> + "test.dummy_op_for_roundtrip"() : () -> !test.test_rec_alias, i32>> return } diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -218,6 +218,10 @@ return AliasResult::FinalAlias; } } + if (auto recAliasType = dyn_cast(type)) { + os << recAliasType.getName(); + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -373,4 +373,22 @@ let mnemonic = "i32"; } +def TestRecursiveAlias + : Test_Type<"TestRecursiveAlias", [NativeTypeTrait<"IsMutable">]> { + let mnemonic = "test_rec_alias"; + let storageClass = "TestRecursiveTypeStorage"; + let storageNamespace = "test"; + let genStorageClass = 0; + + let parameters = (ins "llvm::StringRef":$name); + + let hasCustomAssemblyFormat = 1; + + let extraClassDeclaration = [{ + Type getBody() const; + + void setBody(Type type); + }]; +} + #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -91,9 +91,6 @@ #include "TestTypeInterfaces.h.inc" -#define GET_TYPEDEF_CLASSES -#include "TestTypeDefs.h.inc" - namespace test { /// Storage for simple named recursive types, where the type is identified 
by @@ -150,4 +147,7 @@ } // namespace test +#define GET_TYPEDEF_CLASSES +#include "TestTypeDefs.h.inc" + #endif // MLIR_TESTTYPES_H
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -482,3 +482,69 @@
   SetVector<Type> stack;
   printTestType(type, printer, stack);
 }
+
+/// Returns the body attached to this type via setBody (presumably a null
+/// Type until one has been set).
+Type TestRecursiveAliasType::getBody() const { return getImpl()->body; }
+
+/// Attaches `type` as the body of this mutable type.
+/// NOTE(review): the result of Base::mutate is discarded, so an update the
+/// storage rejects (presumably a conflicting body) is silently ignored.
+void TestRecursiveAliasType::setBody(Type type) { (void)Base::mutate(type); }
+
+/// Returns the name that identifies this type.
+StringRef TestRecursiveAliasType::getName() const { return getImpl()->name; }
+
+/// Parses `<name, body>`, or just `<name>` when a type with the same name is
+/// already being parsed further up the current recursive parse.
+Type TestRecursiveAliasType::parse(AsmParser &parser) {
+  // Types whose bodies are currently being parsed on this thread.
+  thread_local static SetVector<Type> stack;
+
+  StringRef name;
+  if (parser.parseLess() || parser.parseKeyword(&name))
+    return Type();
+  auto rec = TestRecursiveAliasType::get(parser.getContext(), name);
+
+  // If this type already has been parsed above in the stack, expect just the
+  // name.
+  if (stack.contains(rec)) {
+    if (failed(parser.parseGreater()))
+      return Type();
+    return rec;
+  }
+
+  // Otherwise, parse the body and update the type.
+  if (failed(parser.parseComma()))
+    return Type();
+  stack.insert(rec);
+  Type subtype;
+  bool bodyFailed = failed(parser.parseType(subtype));
+  // Pop unconditionally: the stack is thread_local, so an entry left behind
+  // by a failed body parse would make a later parse of the same name on this
+  // thread wrongly accept the short `<name>` form.
+  stack.pop_back();
+  if (bodyFailed || !subtype || failed(parser.parseGreater()))
+    return Type();
+
+  rec.setBody(subtype);
+
+  return rec;
+}
+
+/// Prints `<name, body>`, or just `<name>` when this type is already being
+/// printed further up the current recursive print; this stops cyclic bodies
+/// from being expanded forever.
+void TestRecursiveAliasType::print(AsmPrinter &printer) const {
+  // Types whose bodies are currently being printed on this thread.
+  thread_local static SetVector<Type> stack;
+
+  printer << "<" << getName();
+  if (!stack.contains(*this)) {
+    printer << ", ";
+    stack.insert(*this);
+    printer << getBody();
+    stack.pop_back();
+  }
+  printer << ">";
+}