diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -928,18 +928,29 @@ class OpAsmDialectInterface : public DialectInterface::Base { public: + /// Holds the result of a `getAlias` hook call. + enum class AliasResult { + /// The object (type or attribute) is not supported by the hook + /// and an alias was not provided. + NoAlias, + /// An alias was provided, but it may be overridden by another hook. + OverridableAlias, + /// An alias was provided and it should be used + /// (no other hooks will be checked). + FinalAlias + }; + OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} /// Hooks for getting an alias identifier alias for a given symbol, that is /// not necessarily a part of this dialect. The identifier is used in place of /// the symbol when printing textual IR. These aliases must not contain `.` or - /// end with a numeric digit([0-9]+). Returns success if an alias was - /// provided, failure otherwise. - virtual LogicalResult getAlias(Attribute attr, raw_ostream &os) const { - return failure(); + /// end with a numeric digit([0-9]+). The returned AliasResult reports whether an alias was written to `os` and whether other hooks may override it. + virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { + return AliasResult::NoAlias; } - virtual LogicalResult getAlias(Type type, raw_ostream &os) const { - return failure(); + virtual AliasResult getAlias(Type type, raw_ostream &os) const { + return AliasResult::NoAlias; } /// Get a special name to use when printing the given operation. 
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -652,21 +652,30 @@ template LogicalResult AliasInitializer::generateAlias( T symbol, llvm::MapVector> &aliasToSymbol) { - SmallString<16> tempBuffer; + SmallString<32> nameBuffer; for (const auto &interface : interfaces) { - if (failed(interface.getAlias(symbol, aliasOS))) + OpAsmDialectInterface::AliasResult result = + interface.getAlias(symbol, aliasOS); + if (result == OpAsmDialectInterface::AliasResult::NoAlias) continue; - StringRef name = aliasOS.str(); - assert(!name.empty() && "expected valid alias name"); - name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-", - /*allowTrailingDigit=*/false); - name = name.copy(aliasAllocator); - - aliasToSymbol[name].push_back(symbol); - aliasBuffer.clear(); - return success(); + // Keep this alias; a later interface may replace it unless it is final. + nameBuffer = std::move(aliasBuffer); + assert(!nameBuffer.empty() && "expected valid alias name"); + if (result == OpAsmDialectInterface::AliasResult::FinalAlias) + break; } - return failure(); + + // No interface provided an alias for this symbol. + if (nameBuffer.empty()) + return failure(); + + SmallString<16> tempBuffer; + StringRef name = + sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-", + /*allowTrailingDigit=*/false); + name = name.copy(aliasAllocator); + aliasToSymbol[name].push_back(symbol); + return success(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -33,30 +33,30 @@ struct BuiltinOpAsmDialectInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; - LogicalResult getAlias(Attribute attr, raw_ostream &os) const override { + AliasResult getAlias(Attribute attr, raw_ostream &os) const override { if (attr.isa()) { os << "map"; - return success(); + return 
AliasResult::OverridableAlias; } if (attr.isa()) { os << "set"; - return success(); + return AliasResult::OverridableAlias; } if (attr.isa()) { os << "loc"; - return success(); + return AliasResult::OverridableAlias; } - return failure(); + return AliasResult::NoAlias; } - LogicalResult getAlias(Type type, raw_ostream &os) const final { + AliasResult getAlias(Type type, raw_ostream &os) const final { if (auto tupleType = type.dyn_cast()) { if (tupleType.size() > 16) { os << "tuple"; - return success(); + return AliasResult::OverridableAlias; } } - return failure(); + return AliasResult::NoAlias; } }; } // end anonymous namespace. diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir --- a/mlir/test/IR/print-attr-type-aliases.mlir +++ b/mlir/test/IR/print-attr-type-aliases.mlir @@ -18,6 +18,9 @@ // CHECK-DAG: !tuple = type tuple "test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple) +// CHECK-DAG: !test_tuple = type tuple +"test.op"() {alias_test = "alias_test:large_tuple"} : () -> (tuple) + // CHECK-DAG: #test_encoding = "alias_test:tensor_encoding" // CHECK-DAG: tensor<32xf32, #test_encoding> "test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding"> diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -51,10 +51,10 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { using OpAsmDialectInterface::OpAsmDialectInterface; - LogicalResult getAlias(Attribute attr, raw_ostream &os) const final { + AliasResult getAlias(Attribute attr, raw_ostream &os) const final { StringAttr strAttr = attr.dyn_cast(); if (!strAttr) - return failure(); + return AliasResult::NoAlias; // Check the contents of the string attribute to see what the test alias // should be named. 
@@ -70,10 +70,24 @@ .Case("alias_test:tensor_encoding", StringRef("test_encoding")) .Default(llvm::None); if (!aliasName) - return failure(); + return AliasResult::NoAlias; os << *aliasName; - return success(); + return AliasResult::FinalAlias; + } + + // FinalAlias takes precedence over the builtin overridable `tuple` alias. + AliasResult getAlias(Type type, raw_ostream &os) const final { + if (auto tupleType = type.dyn_cast()) { + if (tupleType.size() > 0 && + llvm::all_of(tupleType.getTypes(), [](Type elemType) { + return elemType.isa(); + })) { + os << "test_tuple"; + return AliasResult::FinalAlias; + } + } + return AliasResult::NoAlias; } void getAsmResultNames(Operation *op,