diff --git a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp --- a/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp +++ b/mlir/lib/Dialect/Func/Transforms/DuplicateFunctionElimination.cpp @@ -48,33 +48,28 @@ return hash; } - static bool isEqual(const func::FuncOp cLhs, const func::FuncOp cRhs) { - if (cLhs == cRhs) { + static bool isEqual(func::FuncOp lhs, func::FuncOp rhs) { + if (lhs == rhs) return true; - } - if (cLhs == getTombstoneKey() || cLhs == getEmptyKey() || - cRhs == getTombstoneKey() || cRhs == getEmptyKey()) { + if (lhs == getTombstoneKey() || lhs == getEmptyKey() || + rhs == getTombstoneKey() || rhs == getEmptyKey()) + return false; + // Check discardable attributes equivalence + if (lhs->getDiscardableAttrDictionary() != + rhs->getDiscardableAttrDictionary()) return false; - } - // Check attributes equivalence, ignoring the symbol name. - if (cLhs->getAttrDictionary().size() != cRhs->getAttrDictionary().size()) { + // Check properties equivalence, ignoring the symbol name. + // Make a copy, so that we can erase the symbol name and perform the + // comparison. + auto pLhs = lhs.getProperties(); + auto pRhs = rhs.getProperties(); + pLhs.sym_name = nullptr; + pRhs.sym_name = nullptr; + if (pLhs != pRhs) return false; - } - func::FuncOp lhs = const_cast<func::FuncOp &>(cLhs); - StringAttr symNameAttrName = lhs.getSymNameAttrName(); - for (NamedAttribute namedAttr : cLhs->getAttrs()) { - StringAttr attrName = namedAttr.getName(); - if (attrName == symNameAttrName) { - continue; - } - if (namedAttr.getValue() != cRhs->getAttr(attrName)) { - return false; - } - } // Compare inner workings. 
- func::FuncOp rhs = const_cast<func::FuncOp &>(cRhs); return OperationEquivalence::isRegionEquivalentTo( &lhs.getBody(), &rhs.getBody(), OperationEquivalence::IgnoreLocations); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp @@ -395,7 +395,7 @@ bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { return lhs->getDiscardableAttrDictionary() == rhs->getDiscardableAttrDictionary() && - lhs->hashProperties() == rhs->hashProperties(); + lhs.getProperties() == rhs.getProperties(); } // Returns a source value for the given block. diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -66,6 +66,9 @@ /// offloaded to the client. std::shared_ptr label; int value; + bool operator==(const PropertiesWithCustomPrint &rhs) const { + return value == rhs.value && *label == *rhs.label; + } }; class MyPropStruct { public: @@ -77,6 +80,9 @@ mlir::Attribute attr, mlir::InFlightDiagnostic *diag); llvm::hash_code hash() const; + bool operator==(const MyPropStruct &rhs) const { + return content == rhs.content; + } }; } // namespace test diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -3433,6 +3433,10 @@ assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); + std::string comparator = + " bool operator==(const Properties &rhs) const {\n" + " return \n"; + llvm::raw_string_ostream comparatorOs(comparator); for (const auto &attrOrProp : attrOrProperties) { if (const auto *namedProperty = attrOrProp.dyn_cast()) { @@ -3447,7 +3451,8 @@ << " " 
<< name << "Ty " << name; if (prop.hasDefaultValue()) os << " = " << prop.getDefaultValue(); - + comparatorOs << " rhs." << name << " == this->" << name + << " &&\n"; // Emit accessors using the interface type. const char *accessorFmt = R"decl(; {0} get{1}() { @@ -3490,6 +3495,7 @@ } os << " using " << name << "Ty = " << storageType << ";\n" << " " << name << "Ty " << name << ";\n"; + comparatorOs << " rhs." << name << " == this->" << name << " &&\n"; // Emit accessors using the interface type. if (attr) { @@ -3509,8 +3515,15 @@ storageType); } } + comparatorOs << " true;\n }\n" + " bool operator!=(const Properties &rhs) const {\n" + " return !(*this == rhs);\n" + " }\n"; + comparatorOs.flush(); + os << comparator; os << " };\n"; os.flush(); + genericAdaptorBase.declare(std::move(declarations)); } genericAdaptorBase.declare(Visibility::Protected);