Index: mlir/include/mlir/IR/Attributes.h =================================================================== --- mlir/include/mlir/IR/Attributes.h +++ mlir/include/mlir/IR/Attributes.h @@ -266,7 +266,8 @@ }; template struct DenseMapInfo< - T, std::enable_if_t::value>> + T, std::enable_if_t::value && + !mlir::detail::IsInterface::value>> : public DenseMapInfo { static T getEmptyKey() { const void *pointer = llvm::DenseMapInfo::getEmptyKey(); Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1963,8 +1963,9 @@ namespace llvm { template -struct DenseMapInfo< - T, std::enable_if_t::value>> { +struct DenseMapInfo::value && + !mlir::detail::IsInterface::value>> { static inline T getEmptyKey() { auto *pointer = llvm::DenseMapInfo::getEmptyKey(); return T::getFromOpaquePointer(pointer); Index: mlir/include/mlir/IR/Types.h =================================================================== --- mlir/include/mlir/IR/Types.h +++ mlir/include/mlir/IR/Types.h @@ -282,7 +282,8 @@ static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; } }; template -struct DenseMapInfo::value>> +struct DenseMapInfo::value && + !mlir::detail::IsInterface::value>> : public DenseMapInfo { static T getEmptyKey() { const void *pointer = llvm::DenseMapInfo::getEmptyKey(); Index: mlir/include/mlir/Support/InterfaceSupport.h =================================================================== --- mlir/include/mlir/Support/InterfaceSupport.h +++ mlir/include/mlir/Support/InterfaceSupport.h @@ -78,6 +78,7 @@ Interface; template using ExternalModel = typename Traits::template ExternalModel; + using ValueType = ValueT; /// This is a special trait that registers a given interface with an object. template @@ -104,6 +105,9 @@ assert((!t || impl) && "expected value to provide interface instance"); } + /// Constructor for DenseMapInfo's empty key and tombstone key. + Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {} + /// Support 'classof' by checking if the given object defines the concrete /// interface. static bool classof(ValueT t) { return ConcreteType::getInterfaceFor(t); } @@ -264,7 +268,37 @@ SmallVector> interfaces; }; +template class> class BaseTrait> +void isInterfaceImpl( + Interface &); + +template +using is_interface_t = decltype(isInterfaceImpl(std::declval())); + +template +using IsInterface = llvm::is_detected; + } // namespace detail } // namespace mlir +template +struct llvm::DenseMapInfo< + T, std::enable_if_t::value>> { + using ValueTypeInfo = llvm::DenseMapInfo; + + static T getEmptyKey() { return T(ValueTypeInfo::getEmptyKey(), nullptr); } + + static T getTombstoneKey() { + return T(ValueTypeInfo::getTombstoneKey(), nullptr); + } + + static unsigned getHashValue(T val) { + return ValueTypeInfo::getHashValue(val); + } + + static bool isEqual(T lhs, T rhs) { return ValueTypeInfo::isEqual(lhs, rhs); } +}; + #endif Index: mlir/unittests/IR/CMakeLists.txt =================================================================== --- mlir/unittests/IR/CMakeLists.txt +++ mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + InterfaceTest.cpp InterfaceAttachmentTest.cpp OperationSupportTest.cpp PatternMatchTest.cpp Index: mlir/unittests/IR/InterfaceTest.cpp =================================================================== --- mlir/unittests/IR/InterfaceTest.cpp +++ mlir/unittests/IR/InterfaceTest.cpp @@ -1,3 +1,45 @@ +//===- InterfaceTest.cpp - Test interfaces --------------------------------===// // -// Created by Markus Böck on 02/07/2022. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestTypes.h" +#include "mlir/IR/OwningOpRef.h" + +using namespace mlir; +using namespace test; + +TEST(InterfaceTest, DenseMapKey) { + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.loadDialect(); + + OwningOpRef module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + auto op1 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op2 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op3 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + DenseSet set; + set.insert(op1); + set.insert(op2); + set.erase(op1); + EXPECT_FALSE(set.contains(op1)); + EXPECT_TRUE(set.contains(op2)); + EXPECT_FALSE(set.contains(op3)); +}