Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h +++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h @@ -17,8 +17,6 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -namespace mlir { #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc" -} // namespace mlir #endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td +++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td @@ -16,6 +16,8 @@ include "mlir/IR/OpBase.td" def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> { + let cppNamespace = "mlir"; + let description = [{ Interface to query characteristics of read-like ops with affine restrictions. @@ -82,6 +84,8 @@ } def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> { + let cppNamespace = "mlir"; + let description = [{ Interface to query characteristics of write-like ops with affine restrictions. @@ -149,6 +153,8 @@ } def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> { + let cppNamespace = "mlir"; + let description = [{ Interface to query the AffineMap used to dereference and access a given memref. Implementers of this interface must operate on at least one Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td =================================================================== --- mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td +++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td @@ -16,6 +16,8 @@ include "mlir/IR/OpBase.td" def TosaOpInterface : OpInterface<"TosaOp"> { + let cppNamespace = "mlir::tosa"; + let description = [{ Implemented by ops that correspond to the Tosa specification. }]; Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h =================================================================== --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h +++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h @@ -26,14 +26,10 @@ namespace mlir { class PatternRewriter; - -namespace tosa { +} // namespace mlir #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc" -} // namespace tosa -} // namespace mlir - //===----------------------------------------------------------------------===// // Utility Functions //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Support/InterfaceSupport.h =================================================================== --- mlir/include/mlir/Support/InterfaceSupport.h +++ mlir/include/mlir/Support/InterfaceSupport.h @@ -116,6 +116,9 @@ const Concept *getImpl() const { return impl; } Concept *getImpl() { return impl; } + /// Constructor for DenseMapInfo's empty key and tombstone key. + Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {} + private: /// A pointer to the impl concept object. Concept *impl; Index: mlir/test/mlir-tblgen/op-interface.td =================================================================== --- mlir/test/mlir-tblgen/op-interface.td +++ mlir/test/mlir-tblgen/op-interface.td @@ -78,6 +78,8 @@ // DECL: template // DECL: int detail::TestOpInterfaceInterfaceTraits::Model::foo +// DECL: template <> struct DenseMapInfo<::TestOpInterface> + // DECL-LABEL: struct TestOpInterfaceVerifyTrait // DECL: verifyTrait Index: mlir/tools/mlir-tblgen/OpInterfacesGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -90,6 +90,7 @@ void emitTraitDecl(const Interface &interface, StringRef interfaceName, StringRef interfaceTraitsName); void emitInterfaceDecl(const Interface &interface); + void emitDenseMapInfo(const Interface &interface); /// The set of interface records to emit. std::vector defs; @@ -429,6 +430,37 @@ os << " };\n"; } +void InterfaceGenerator::emitDenseMapInfo( + const mlir::tblgen::Interface &interface) { + const char *format = R"( +namespace llvm { +template <> struct DenseMapInfo<{0}> {{ + using BaseTypeInfo = llvm::DenseMapInfo<{1}>; + + static {0} getEmptyKey() {{ + return {0}(BaseTypeInfo::getEmptyKey(), nullptr); + } + + static {0} getTombstoneKey() {{ + return {0}(BaseTypeInfo::getTombstoneKey(), nullptr); + } + + static unsigned getHashValue({0} val) {{ + return BaseTypeInfo::getHashValue(val); + } + + static bool isEqual({0} lhs, {0} rhs) {{ + return BaseTypeInfo::isEqual(lhs, rhs); + } +}; +})"; + + os << llvm::formatv(format, + interface.getCppNamespace() + "::" + interface.getName(), + valueType); + os << "\n"; +} + void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -458,7 +490,8 @@ // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", + " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" + " friend struct llvm::DenseMapInfo<{0}>;\n", interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); @@ -493,6 +526,8 @@ for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; + + emitDenseMapInfo(interface); } bool InterfaceGenerator::emitInterfaceDecls() { Index: mlir/unittests/IR/CMakeLists.txt =================================================================== --- mlir/unittests/IR/CMakeLists.txt +++ mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + InterfaceTest.cpp InterfaceAttachmentTest.cpp OperationSupportTest.cpp PatternMatchTest.cpp Index: mlir/unittests/IR/InterfaceTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/IR/InterfaceTest.cpp @@ -0,0 +1,45 @@ +//===- InterfaceTest.cpp - Test interfaces --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestTypes.h" +#include "mlir/IR/OwningOpRef.h" + +using namespace mlir; +using namespace test; + +TEST(InterfaceTest, DenseMapKey) { + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.loadDialect(); + + OwningOpRef module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + auto op1 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op2 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op3 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + DenseSet set; + set.insert(op1); + set.insert(op2); + set.erase(op1); + EXPECT_FALSE(set.contains(op1)); + EXPECT_TRUE(set.contains(op2)); + EXPECT_FALSE(set.contains(op3)); +}