Index: mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h
===================================================================
--- mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h
+++ mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.h
@@ -16,13 +16,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-namespace toy {
-
 /// Include the auto-generated declarations.
 #include "toy/ShapeInferenceOpInterfaces.h.inc"
 
-} // namespace toy
-} // namespace mlir
-
 #endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
Index: mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
===================================================================
--- mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
+++ mlir/examples/toy/Ch4/include/toy/ShapeInferenceInterface.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let cppNamespace = "mlir::toy";
+
   let description = [{
     Interface to access a registered method to infer the return types for an
     operation that can be used during type inference.
Index: mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h
===================================================================
--- mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h
+++ mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.h
@@ -16,13 +16,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-namespace toy {
-
 /// Include the auto-generated declarations.
 #include "toy/ShapeInferenceOpInterfaces.h.inc"
 
-} // namespace toy
-} // namespace mlir
-
 #endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
Index: mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td
===================================================================
--- mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td
+++ mlir/examples/toy/Ch5/include/toy/ShapeInferenceInterface.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let cppNamespace = "mlir::toy";
+
   let description = [{
     Interface to access a registered method to infer the return types for an
     operation that can be used during type inference.
Index: mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h
===================================================================
--- mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h
+++ mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.h
@@ -16,13 +16,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-namespace toy {
-
 /// Include the auto-generated declarations.
 #include "toy/ShapeInferenceOpInterfaces.h.inc"
 
-} // namespace toy
-} // namespace mlir
-
 #endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
Index: mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td
===================================================================
--- mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td
+++ mlir/examples/toy/Ch6/include/toy/ShapeInferenceInterface.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let cppNamespace = "mlir::toy";
+
   let description = [{
     Interface to access a registered method to infer the return types for an
     operation that can be used during type inference.
Index: mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h
===================================================================
--- mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h
+++ mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.h
@@ -16,13 +16,7 @@
 
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
-namespace toy {
-
 /// Include the auto-generated declarations.
 #include "toy/ShapeInferenceOpInterfaces.h.inc"
 
-} // namespace toy
-} // namespace mlir
-
 #endif // MLIR_TUTORIAL_TOY_SHAPEINFERENCEINTERFACE_H_
Index: mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td
===================================================================
--- mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td
+++ mlir/examples/toy/Ch7/include/toy/ShapeInferenceInterface.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
+  let cppNamespace = "mlir::toy";
+
   let description = [{
     Interface to access a registered method to infer the return types for an
     operation that can be used during type inference.
Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
+++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h
@@ -17,8 +17,6 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
 
-namespace mlir {
 #include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h.inc"
-} // namespace mlir
 
 #endif // MLIR_DIALECT_AFFINE_IR_AFFINEMEMORYOPINTERFACES_H
Index: mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
===================================================================
--- mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
+++ mlir/include/mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def AffineReadOpInterface : OpInterface<"AffineReadOpInterface"> {
+  let cppNamespace = "mlir";
+
   let description = [{
     Interface to query characteristics of read-like ops with affine
     restrictions.
@@ -82,6 +84,8 @@
 }
 
 def AffineWriteOpInterface : OpInterface<"AffineWriteOpInterface"> {
+  let cppNamespace = "mlir";
+
   let description = [{
     Interface to query characteristics of write-like ops with affine
     restrictions.
@@ -149,6 +153,8 @@
 }
 
 def AffineMapAccessInterface : OpInterface<"AffineMapAccessInterface"> {
+  let cppNamespace = "mlir";
+
   let description = [{
     Interface to query the AffineMap used to dereference and access a given
     memref. Implementers of this interface must operate on at least one
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaInterfaces.td
@@ -16,6 +16,8 @@
 include "mlir/IR/OpBase.td"
 
 def TosaOpInterface : OpInterface<"TosaOp"> {
+  let cppNamespace = "mlir::tosa";
+
   let description = [{
     Implemented by ops that correspond to the Tosa specification.
   }];
Index: mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
===================================================================
--- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -26,14 +26,10 @@
 
 namespace mlir {
 class PatternRewriter;
-
-namespace tosa {
+} // namespace mlir
 
 #include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
 
-} // namespace tosa
-} // namespace mlir
-
 //===----------------------------------------------------------------------===//
 // Utility Functions
 //===----------------------------------------------------------------------===//
Index: mlir/include/mlir/Support/InterfaceSupport.h
===================================================================
--- mlir/include/mlir/Support/InterfaceSupport.h
+++ mlir/include/mlir/Support/InterfaceSupport.h
@@ -78,6 +78,7 @@
       Interface<ConcreteType, ValueT, Traits, BaseType, BaseTrait>;
   template <typename T, typename U>
   using ExternalModel = typename Traits::template ExternalModel<T, U>;
+  using ValueType = ValueT;
 
   /// This is a special trait that registers a given interface with an object.
   template <typename ConcreteT>
@@ -116,6 +117,9 @@
   const Concept *getImpl() const { return impl; }
   Concept *getImpl() { return impl; }
 
+  /// Construct with a null impl; only for DenseMapInfo empty/tombstone keys.
+  Interface(ValueT t, std::nullptr_t) : BaseType(t), impl(nullptr) {}
+
 private:
   /// A pointer to the impl concept object.
   Concept *impl;
@@ -267,4 +271,34 @@
 } // namespace detail
 } // namespace mlir
 
+/// Convenience macro to define a 'llvm::DenseMapInfo' specialization for the
+/// given interface. It must be expanded at global scope, and the interface
+/// class must befriend the specialization, e.g. by adding
+/// 'friend struct ::llvm::DenseMapInfo<INTERFACE_NAME>;' to the class body.
+#define MLIR_DEFINE_INTERFACE_DENSE_MAP_INFO(INTERFACE_NAME)                    \
+  namespace llvm {                                                             \
+  template <> struct DenseMapInfo<INTERFACE_NAME> {                            \
+    using ValueTypeInfo = DenseMapInfo<INTERFACE_NAME::ValueType>;             \
+    static INTERFACE_NAME getEmptyKey() {                                      \
+      return INTERFACE_NAME(ValueTypeInfo::getEmptyKey(), nullptr);            \
+    }                                                                          \
+                                                                               \
+    static INTERFACE_NAME getTombstoneKey() {                                  \
+      return INTERFACE_NAME(ValueTypeInfo::getTombstoneKey(), nullptr);        \
+    }                                                                          \
+                                                                               \
+    static unsigned getHashValue(INTERFACE_NAME val) {                         \
+      return ValueTypeInfo::getHashValue(val);                                 \
+    }                                                                          \
+                                                                               \
+    static bool isEqual(INTERFACE_NAME lhs, INTERFACE_NAME rhs) {              \
+      return ValueTypeInfo::isEqual(lhs, rhs);                                 \
+    }                                                                          \
+  };                                                                           \
+  } // namespace llvm
+
+// Sentinel type declared at global scope. Generated interface declaration
+// files static_assert against it to detect inclusion inside a namespace.
+struct MLIR_TEST_GLOBAL_NAMESPACE_INCLUDE;
+
 #endif
Index: mlir/test/mlir-tblgen/op-interface.td
===================================================================
--- mlir/test/mlir-tblgen/op-interface.td
+++ mlir/test/mlir-tblgen/op-interface.td
@@ -78,6 +78,8 @@
 // DECL: template<typename ConcreteOp>
 // DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
 
+// DECL: MLIR_DEFINE_INTERFACE_DENSE_MAP_INFO(::TestOpInterface)
+
 // DECL-LABEL: struct TestOpInterfaceVerifyTrait
 // DECL: verifyTrait
 
Index: mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
+++ mlir/tools/mlir-tblgen/OpInterfacesGen.cpp
@@ -90,6 +90,7 @@
   void emitTraitDecl(const Interface &interface, StringRef interfaceName,
                      StringRef interfaceTraitsName);
   void emitInterfaceDecl(const Interface &interface);
+  void emitDenseMapInfo(const Interface &interface);
 
   /// The set of interface records to emit.
std::vector defs; @@ -429,6 +430,13 @@ os << " };\n"; } +void InterfaceGenerator::emitDenseMapInfo( + const mlir::tblgen::Interface &interface) { + os << llvm::formatv("MLIR_DEFINE_INTERFACE_DENSE_MAP_INFO({0})", + interface.getCppNamespace() + "::" + interface.getName()); + os << "\n"; +} + void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -458,7 +466,8 @@ // Emit the main interface class declaration. os << llvm::formatv("class {0} : public ::mlir::{3}<{1}, detail::{2}> {\n" "public:\n" - " using ::mlir::{3}<{1}, detail::{2}>::{3};\n", + " using ::mlir::{3}<{1}, detail::{2}>::{3};\n" + " friend struct llvm::DenseMapInfo<{0}>;\n", interfaceName, interfaceName, interfaceTraitsName, interfaceBaseType); @@ -493,11 +502,25 @@ for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; + + emitDenseMapInfo(interface); } bool InterfaceGenerator::emitInterfaceDecls() { llvm::emitSourceFileHeader("Interface Declarations", os); + // Emit preamble asserting the file is being included in the global namespace. + os << R"( +struct MLIR_TEST_GLOBAL_NAMESPACE_INCLUDE; + +static_assert(std::is_same::value, + "Including interface declarations inside of a namespace is no " + "longer supported. 
Use the 'cppNamespace' field in TableGen " + "and include the file in the global namespace instead."); + +)"; + for (const auto *def : defs) emitInterfaceDecl(Interface(def)); return false; Index: mlir/unittests/IR/CMakeLists.txt =================================================================== --- mlir/unittests/IR/CMakeLists.txt +++ mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests AttributeTest.cpp DialectTest.cpp + InterfaceTest.cpp InterfaceAttachmentTest.cpp OperationSupportTest.cpp PatternMatchTest.cpp Index: mlir/unittests/IR/InterfaceTest.cpp =================================================================== --- /dev/null +++ mlir/unittests/IR/InterfaceTest.cpp @@ -0,0 +1,45 @@ +//===- InterfaceTest.cpp - Test interfaces --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "gtest/gtest.h" + +#include "../../test/lib/Dialect/Test/TestAttributes.h" +#include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestTypes.h" +#include "mlir/IR/OwningOpRef.h" + +using namespace mlir; +using namespace test; + +TEST(InterfaceTest, DenseMapKey) { + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.loadDialect(); + + OwningOpRef module = ModuleOp::create(UnknownLoc::get(&context)); + OpBuilder builder(module->getBody(), module->getBody()->begin()); + auto op1 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op2 = builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + auto op3 = 
builder.create(builder.getUnknownLoc(), + builder.getI32Type()); + DenseSet set; + set.insert(op1); + set.insert(op2); + set.erase(op1); + EXPECT_FALSE(set.contains(op1)); + EXPECT_TRUE(set.contains(op2)); + EXPECT_FALSE(set.contains(op3)); +}