diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -49,7 +49,7 @@ /// | distinct-attribute /// | extended-attribute /// -Attribute Parser::parseAttribute(Type type) { +Attribute Parser::parseAttribute(Type type, bool isAliasDef) { switch (getToken().getKind()) { // Parse an AffineMap or IntegerSet attribute. case Token::kw_affine_map: { @@ -157,8 +157,12 @@ return parseStridedLayoutAttr(); // Parse a distinct attribute. - case Token::kw_distinct: - return parseDistinctAttr(type); + case Token::kw_distinct: { + if (isAliasDef) + return parseDistinctAttr(type); + + return (emitError("distinct attribute cannot be defined inline"), nullptr); + } // Parse a string attribute. case Token::string: { @@ -1227,45 +1231,14 @@ /// Attribute Parser::parseDistinctAttr(Type type) { consumeToken(Token::kw_distinct); - if (parseToken(Token::l_square, "expected '[' after 'distinct'")) - return {}; - - // Parse the distinct integer identifier. - Token token = getToken(); - if (parseToken(Token::integer, "expected distinct Id")) - return {}; - std::optional value = token.getUInt64IntegerValue(); - if (!value) { - emitError("expected an unsigned 64-bit integer"); - return {}; - } - - // Parse the referenced attribute. - if (parseToken(Token::r_square, "expected ']' to close distinct Id") || - parseToken(Token::less, "expected '<' after distinct Id")) + if (parseToken(Token::less, "expected '<' after 'distinct'")) return {}; Attribute referencedAttr = parseAttribute(type); if (!referencedAttr) { emitError("expected attribute"); return {}; } - - // Add the distinct attribute to the parser state, if it has not been parsed - // before. Otherwise, check if the parsed reference attribute matches the one - // found in the parser state.
- DenseMap &distinctAttrs = - state.symbols.distinctAttributes; - auto it = distinctAttrs.find(*value); - if (it == distinctAttrs.end()) { - DistinctAttr distinctAttr = DistinctAttr::create(referencedAttr); - it = distinctAttrs.try_emplace(*value, distinctAttr).first; - } else if (it->getSecond().getReferencedAttr() != referencedAttr) { - emitError("referenced attribute does not match previous definition"); - return {}; - } - if (parseToken(Token::greater, "expected '>' to close distinct attribute")) return {}; - - return it->getSecond(); + return DistinctAttr::create(referencedAttr); } diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -223,7 +223,7 @@ //===--------------------------------------------------------------------===// /// Parse an arbitrary attribute with an optional type. - Attribute parseAttribute(Type type = {}); + Attribute parseAttribute(Type type = {}, bool isAliasDef = false); /// Parse an optional attribute with the provided type. OptionalParseResult parseOptionalAttribute(Attribute &attribute, diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -2534,7 +2534,7 @@ return failure(); // Parse the attribute value. - Attribute attr = parseAttribute(); + Attribute attr = parseAttribute(/*type=*/{}, /*isAliasDef=*/true); if (!attr) return failure(); diff --git a/mlir/lib/AsmParser/ParserState.h b/mlir/lib/AsmParser/ParserState.h --- a/mlir/lib/AsmParser/ParserState.h +++ b/mlir/lib/AsmParser/ParserState.h @@ -36,9 +36,6 @@ DenseMap>> dialectResources; - - /// A map from unique integer identifier to DistinctAttr. 
- DenseMap distinctAttributes; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -1002,6 +1003,18 @@ DummyAliasOperationPrinter aliasPrinter(printerFlags, *this); aliasPrinter.printCustomOrGenericOp(op); + // Drop all aliases except for distinct attribute aliases if the operation is + // not a top-level operation or if the use local scope flag is set. + if (printerFlags.shouldUseLocalScope() || op->getParentOp()) { + aliases.remove_if([](const auto &aliasIt) { + const auto &[opaqueSymbol, alias] = aliasIt; + if (alias.isType) + return true; + Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol); + return !isa(attr); + }); + } + // Initialize the aliases. initializeAliases(aliases, attrTypeToAlias); } @@ -1152,8 +1165,10 @@ auto filterFn = [=](const auto &aliasIt) { return aliasIt.second.canBeDeferred() == isDeferred; }; - for (auto &[opaqueSymbol, alias] : - llvm::make_filter_range(attrTypeToAlias, filterFn)) { + auto aliasRange = llvm::make_filter_range(attrTypeToAlias, filterFn); + if (isDeferred && !aliasRange.empty()) + p.getStream() << newLine; + for (auto &[opaqueSymbol, alias] : aliasRange) { alias.print(p.getStream()); p.getStream() << " = "; @@ -1606,31 +1621,6 @@ return name; } -//===----------------------------------------------------------------------===// -// DistinctState -//===----------------------------------------------------------------------===// - -namespace { -/// This class manages the state for distinct attributes. -class DistinctState { -public: - /// Returns a unique identifier for the given distinct attribute. 
- uint64_t getId(DistinctAttr distinctAttr); - -private: - uint64_t distinctCounter = 0; - DenseMap distinctAttrMap; -}; -} // namespace - -uint64_t DistinctState::getId(DistinctAttr distinctAttr) { - auto [it, inserted] = - distinctAttrMap.try_emplace(distinctAttr, distinctCounter); - if (inserted) - distinctCounter++; - return it->getSecond(); -} - //===----------------------------------------------------------------------===// // Resources //===----------------------------------------------------------------------===// @@ -1742,9 +1732,6 @@ /// Get the state used for SSA names. SSANameState &getSSANameState() { return nameState; } - /// Get the state used for distinct attribute identifiers. - DistinctState &getDistinctState() { return distinctState; } - /// Return the dialects within the context that implement /// OpAsmDialectInterface. DialectInterfaceCollection &getDialectInterfaces() { @@ -1788,9 +1775,6 @@ /// The state used for SSA value names. SSANameState nameState; - /// The state used for distinct attribute identifiers. - DistinctState distinctState; - /// Flags that control op output. OpPrintingFlags printerFlags; @@ -2140,7 +2124,7 @@ os << "unit"; return; } else if (auto distinctAttr = llvm::dyn_cast(attr)) { - os << "distinct[" << state.getDistinctState().getId(distinctAttr) << "]<"; + os << "distinct<"; printAttribute(distinctAttr.getReferencedAttr()); os << '>'; return; @@ -2945,8 +2929,9 @@ explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) : Impl(os, state), OpAsmPrinter(static_cast(*this)) {} - /// Print the given top-level operation. - void printTopLevelOperation(Operation *op); + /// Print the given operation, including its left-hand side and its right-hand + /// side, with its indent, location, and required aliases. + void printFullOpWithIndentLocAndAliases(Operation *op); /// Print the given operation, including its left-hand side and its right-hand /// side, with its indent and location.
@@ -3132,17 +3117,15 @@ }; } // namespace -void OperationPrinter::printTopLevelOperation(Operation *op) { - // Output the aliases at the top level that can't be deferred. +void OperationPrinter::printFullOpWithIndentLocAndAliases(Operation *op) { + // Output the aliases that can't be deferred. state.getAliasState().printNonDeferredAliases(*this, newLine); - // Print the module. + // Print the operation itself. printFullOpWithIndentAndLoc(op); - os << newLine; - // Output the aliases at the top level that can be deferred. + // Output the aliases that can be deferred. state.getAliasState().printDeferredAliases(*this, newLine); - // Output any file level metadata. printFileMetadataDictionary(op); } @@ -3753,12 +3736,8 @@ } void Operation::print(raw_ostream &os, AsmState &state) { OperationPrinter printer(os, state.getImpl()); - if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) { - state.getImpl().initializeAliases(this); - printer.printTopLevelOperation(this); - } else { - printer.printFullOpWithIndentAndLoc(this); - } + state.getImpl().initializeAliases(this); + printer.printFullOpWithIndentLocAndAliases(this); } void Operation::dump() { diff --git a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir --- a/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir +++ b/mlir/test/Dialect/Builtin/Bytecode/attrs.mlir @@ -3,6 +3,26 @@ // Bytecode currently does not support big-endian platforms // UNSUPPORTED: target=s390x-{{.*}} + +//===----------------------------------------------------------------------===// +// DistinctAttr +//===----------------------------------------------------------------------===// + +// CHECK: #[[$DIST0:.*]] = distinct<42 : i32> +// CHECK: #[[$DIST1:.*]] = distinct<42 : i32> +#distinct = distinct<42 : i32> +#distinct1 = distinct<42 : i32> + +// CHECK-LABEL: @TestDistinct +module @TestDistinct attributes { + // CHECK: bytecode.distinct = #[[$DIST0]] + // CHECK: bytecode.distinct2 = #[[$DIST0]] + // CHECK: bytecode.distinct3 =
#distinct1 +} {} + //===----------------------------------------------------------------------===// // ArrayAttr //===----------------------------------------------------------------------===// @@ -128,20 +148,6 @@ bytecode.type = i178 } {} -//===----------------------------------------------------------------------===// -// DistinctAttr -//===----------------------------------------------------------------------===// - -// CHECK-LABEL: @TestDistinct -module @TestDistinct attributes { - // CHECK: bytecode.distinct = distinct[0]<42 : i32> - // CHECK: bytecode.distinct2 = distinct[0]<42 : i32> - // CHECK: bytecode.distinct3 = distinct[1]<42 : i32> - bytecode.distinct = distinct[0]<42 : i32>, - bytecode.distinct2 = distinct[0]<42 : i32>, - bytecode.distinct3 = distinct[1]<42 : i32> -} {} - //===----------------------------------------------------------------------===// // CallSiteLoc //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/distinct-attr.mlir b/mlir/test/IR/distinct-attr.mlir --- a/mlir/test/IR/distinct-attr.mlir +++ b/mlir/test/IR/distinct-attr.mlir @@ -1,22 +1,20 @@ // RUN: mlir-opt -allow-unregistered-dialect %s | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect -mlir-print-local-scope %s | FileCheck %s --check-prefix=CHECK-GENERIC -// CHECK: #[[DISTINCT0:.*]] = distinct[0]<42 : i32> -// CHECK: #[[DISTINCT1:.*]] = distinct[1]> -// CHECK: #[[DISTINCT2:.*]] = distinct[2]<42 : i32> +// CHECK: #[[DISTINCT0:.*]] = distinct<42 : i32> +#distinct = distinct<42 : i32> +// CHECK: #[[DISTINCT1:.*]] = distinct> +#distinct1 = distinct> +// CHECK: #[[DISTINCT2:.*]] = distinct<42 : i32> +#distinct2 = distinct<42 : i32> // CHECK: distinct.attr = #[[DISTINCT0]] -// CHECK-GENERIC: distinct.attr = distinct[0]<42 : i32> -"foo.op"() {distinct.attr = distinct[0]<42 : i32>} : () -> () +"foo.op"() {distinct.attr = #distinct} : () -> () // CHECK: distinct.attr = #[[DISTINCT1]] -// CHECK-GENERIC: distinct.attr = 
distinct[1]> -"foo.op"() {distinct.attr = distinct[1]>} : () -> () +"foo.op"() {distinct.attr = #distinct1} : () -> () // CHECK: distinct.attr = #[[DISTINCT0]] -// CHECK-GENERIC: distinct.attr = distinct[0]<42 : i32> -"foo.op"() {distinct.attr = distinct[0]<42 : i32>} : () -> () +"foo.op"() {distinct.attr = #distinct} : () -> () // CHECK: distinct.attr = #[[DISTINCT2]] -// CHECK-GENERIC: distinct.attr = distinct[2]<42 : i32> -"foo.op"() {distinct.attr = distinct[42]<42 : i32>} : () -> () +"foo.op"() {distinct.attr = #distinct2} : () -> () diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -549,41 +549,21 @@ // ----- -// expected-error@below {{expected '[' after 'distinct'}} -#attr = distinct< - -// ----- - -// expected-error@below {{expected distinct Id}} -#attr = distinct[i8 - -// ----- - -// expected-error@below {{expected an unsigned 64-bit integer}} -#attr = distinct[0xAAAABBBBEEEEFFFF1] - -// ----- - -// expected-error@below {{expected ']' to close distinct Id}} -#attr = distinct[8) - -// ----- - -// expected-error@below {{expected '<' after distinct Id}} -#attr = distinct[8]( +// expected-error@below {{expected '<' after 'distinct'}} +#attr = distinct[ // ----- // expected-error@below {{expected attribute}} -#attr = distinct[8]' to close distinct attribute}} -#attr = distinct[8]<@foo] +#attr = distinct<@foo] // ----- -#attr = distinct[0]<42 : i32> -// expected-error@below {{referenced attribute does not match previous definition}} -#attr1 = distinct[0]<43 : i32> +func.func @inline_distinct() -> () { + "foo"(){bar = distinct<42 : i32>} : () -> () // expected-error {{distinct attribute cannot be defined inline}} +} diff --git a/mlir/test/IR/test-builtin-distinct-attrs.mlir b/mlir/test/IR/test-builtin-distinct-attrs.mlir --- a/mlir/test/IR/test-builtin-distinct-attrs.mlir +++ 
b/mlir/test/IR/test-builtin-distinct-attrs.mlir @@ -1,35 +1,35 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -test-distinct-attrs | FileCheck %s -// CHECK: #[[DIST0:.*]] = distinct[0]<42 : i32> -// CHECK: #[[DIST1:.*]] = distinct[1]<42 : i32> -#distinct = distinct[0]<42 : i32> -// CHECK: #[[DIST2:.*]] = distinct[2]<42 : i32> -// CHECK: #[[DIST3:.*]] = distinct[3]<42 : i32> -#distinct1 = distinct[1]<42 : i32> -// CHECK: #[[DIST4:.*]] = distinct[4]<43 : i32> -// CHECK: #[[DIST5:.*]] = distinct[5]<43 : i32> -#distinct2 = distinct[2]<43 : i32> -// CHECK: #[[DIST6:.*]] = distinct[6]<@foo_1> -// CHECK: #[[DIST7:.*]] = distinct[7]<@foo_1> -#distinct3 = distinct[3]<@foo_1> +// CHECK: #[[DIST0:.*]] = distinct<42 : i32> +// CHECK: #[[DIST1:.*]] = distinct<42 : i32> +#distinct = distinct<42 : i32> +// CHECK: #[[DIST2:.*]] = distinct<42 : i32> +// CHECK: #[[DIST3:.*]] = distinct<42 : i32> +#distinct1 = distinct<42 : i32> +// CHECK: #[[DIST4:.*]] = distinct<43 : i32> +// CHECK: #[[DIST5:.*]] = distinct<43 : i32> +#distinct2 = distinct<43 : i32> +// CHECK: #[[DIST6:.*]] = distinct<@foo_1> +// CHECK: #[[DIST7:.*]] = distinct<@foo_1> +#distinct3 = distinct<@foo_1> // Copies made for foo_2 -// CHECK: #[[DIST8:.*]] = distinct[8]<42 : i32> -// CHECK: #[[DIST9:.*]] = distinct[9]<42 : i32> -// CHECK: #[[DIST10:.*]] = distinct[10]<43 : i32> -// CHECK: #[[DIST11:.*]] = distinct[11]<@foo_1> +// CHECK: #[[DIST8:.*]] = distinct<42 : i32> +// CHECK: #[[DIST9:.*]] = distinct<42 : i32> +// CHECK: #[[DIST10:.*]] = distinct<43 : i32> +// CHECK: #[[DIST11:.*]] = distinct<@foo_1> // Copies made for foo_3 -// CHECK: #[[DIST12:.*]] = distinct[12]<42 : i32> -// CHECK: #[[DIST13:.*]] = distinct[13]<42 : i32> -// CHECK: #[[DIST14:.*]] = distinct[14]<43 : i32> -// CHECK: #[[DIST15:.*]] = distinct[15]<@foo_1> +// CHECK: #[[DIST12:.*]] = distinct<42 : i32> +// CHECK: #[[DIST13:.*]] = distinct<42 : i32> +// CHECK: #[[DIST14:.*]] = distinct<43 : i32> +// CHECK: #[[DIST15:.*]] = distinct<@foo_1> // 
Copies made for foo_4 -// CHECK: #[[DIST16:.*]] = distinct[16]<42 : i32> -// CHECK: #[[DIST17:.*]] = distinct[17]<42 : i32> -// CHECK: #[[DIST18:.*]] = distinct[18]<43 : i32> -// CHECK: #[[DIST19:.*]] = distinct[19]<@foo_1> +// CHECK: #[[DIST16:.*]] = distinct<42 : i32> +// CHECK: #[[DIST17:.*]] = distinct<42 : i32> +// CHECK: #[[DIST18:.*]] = distinct<43 : i32> +// CHECK: #[[DIST19:.*]] = distinct<@foo_1> // CHECK: @foo_1 func.func @foo_1() { diff --git a/mlir/test/IR/test-symbol-rauw.mlir b/mlir/test/IR/test-symbol-rauw.mlir --- a/mlir/test/IR/test-symbol-rauw.mlir +++ b/mlir/test/IR/test-symbol-rauw.mlir @@ -76,6 +76,9 @@ // ----- +// CHECK: #[[DIST:.*]] = distinct<@replaced_foo> +#distinct = distinct<@symbol_foo> + // Check that replacement works in any implementations of SubElements. module { // CHECK: func private @replaced_foo @@ -85,11 +88,11 @@ func.func @symbol_bar() { // CHECK: foo.op // CHECK-SAME: non_symbol_attr, - // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>, distinct[0]<@replaced_foo>], + // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>, #[[DIST]]], // CHECK-SAME: z_non_symbol_attr_3 "foo.op"() { non_symbol_attr, - use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>, distinct[0]<@symbol_foo>], + use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>, #distinct], z_non_symbol_attr_3 } : () -> () } diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir --- a/mlir/test/IR/test-symbol-uses.mlir +++ b/mlir/test/IR/test-symbol-uses.mlir @@ -1,5 +1,7 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -test-symbol-uses -split-input-file -verify-diagnostics +#distinct = distinct<@symbol_foo> + // Symbol references to the module itself don't affect uses of symbols within // its table. 
// expected-remark@below {{symbol_removable function successfully erased}} @@ -18,7 +20,7 @@ z_other_non_symbol_attr } : () -> () // expected-remark@+1 {{found use of symbol : @symbol_foo}} - "foo.op"() { use = distinct[0]<@symbol_foo> } : () -> () + "foo.op"() { use = #distinct } : () -> () } // expected-remark@below {{symbol has no uses}} diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp --- a/mlir/unittests/IR/OpPropertiesTest.cpp +++ b/mlir/unittests/IR/OpPropertiesTest.cpp @@ -169,7 +169,7 @@ "<{a = -42 : i32, " "array = array, " "b = -4.200000e+01 : f32, " - "label = \"bar foo\"}> : () -> ()\n", + "label = \"bar foo\"}> : () -> ()", os.str().c_str()); } // Get a mutable reference to the properties for this operation and modify it