Index: mlir/docs/PDLL.md =================================================================== --- mlir/docs/PDLL.md +++ mlir/docs/PDLL.md @@ -1225,7 +1225,7 @@ - Imported `Type` constraints utilize the `cppClassName` field for native type translation. * `AttrInterface`/`OpInterface`/`TypeInterface` constraints - - Imported interfaces utilize the `cppClassName` field for native type translation. + - Imported interfaces utilize the `cppInterfaceName` field for native type translation. #### Defining Constraints Inline Index: mlir/include/mlir/IR/OpBase.td =================================================================== --- mlir/include/mlir/IR/OpBase.td +++ mlir/include/mlir/IR/OpBase.td @@ -1949,7 +1949,7 @@ string description = ""; // The name given to the c++ interface class. - string cppClassName = name; + string cppInterfaceName = name; // The C++ namespace that this interface should be placed into. // @@ -1970,13 +1970,22 @@ } // AttrInterface represents an interface registered to an attribute. -class AttrInterface : Interface, InterfaceTrait; +class AttrInterface : Interface, InterfaceTrait, + Attr()">, + name # " instance"> +{ + let storageType = cppNamespace # "::" # name; + let returnType = storageType; + let convertFromStorage = "$_self"; +} // OpInterface represents an interface registered to an operation. class OpInterface : Interface, OpInterfaceTrait; // TypeInterface represents an interface registered to a type. -class TypeInterface : Interface, InterfaceTrait; +class TypeInterface : Interface, InterfaceTrait, + Type()">, + name # " instance", cppNamespace # "::" # name>; // Whether to declare the interface methods in the user entity's header.
This // class simply wraps an Interface but is used to indicate that the method @@ -1992,27 +2001,27 @@ class DeclareAttrInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - AttrInterface { + AttrInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareOpInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - OpInterface { + OpInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } class DeclareTypeInterfaceMethods overridenMethods = []> : DeclareInterfaceMethods, - TypeInterface { + TypeInterface { let description = interface.description; - let cppClassName = interface.cppClassName; + let cppInterfaceName = interface.cppInterfaceName; let cppNamespace = interface.cppNamespace; let methods = interface.methods; } Index: mlir/lib/TableGen/Interfaces.cpp =================================================================== --- mlir/lib/TableGen/Interfaces.cpp +++ mlir/lib/TableGen/Interfaces.cpp @@ -81,7 +81,7 @@ // Return the name of this interface. StringRef Interface::getName() const { - return def->getValueAsString("cppClassName"); + return def->getValueAsString("cppInterfaceName"); } // Return the C++ namespace of this interface. Index: mlir/lib/Tools/PDLL/Parser/Parser.cpp =================================================================== --- mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -893,9 +893,9 @@ constraint.getCPPClassName())); } } - /// Interfaces. + /// OpInterfaces. AttrInterface/TypeInterface now also derive from Attr/Type, so they are picked up as attribute/type constraints instead of here.
ast::Type opTy = ast::OperationType::get(ctx); - for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("Interface")) { + for (llvm::Record *def : tdRecords.getAllDerivedDefinitions("OpInterface")) { StringRef name = def->getName(); if (def->isAnonymous() || curDeclScope->lookup(name) || def->isSubClassOf("DeclareInterfaceMethods")) @@ -904,7 +904,7 @@ std::string cppClassName = llvm::formatv("{0}::{1}", def->getValueAsString("cppNamespace"), - def->getValueAsString("cppClassName")) + def->getValueAsString("cppInterfaceName")) .str(); std::string codeBlock = llvm::formatv("return ::mlir::success(llvm::isa<{0}>(self));", @@ -913,18 +913,8 @@ std::string desc = processAndFormatDoc(def->getValueAsString("description")); - if (def->isSubClassOf("OpInterface")) { - decls.push_back(createODSNativePDLLConstraintDecl( - name, codeBlock, loc, opTy, cppClassName, desc)); - } else if (def->isSubClassOf("AttrInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, attrTy, cppClassName, desc)); - } else if (def->isSubClassOf("TypeInterface")) { - decls.push_back( - createODSNativePDLLConstraintDecl( - name, codeBlock, loc, typeTy, cppClassName, desc)); - } + decls.push_back(createODSNativePDLLConstraintDecl( + name, codeBlock, loc, opTy, cppClassName, desc)); } } Index: mlir/test/mlir-pdll/Parser/include_td.pdll =================================================================== --- mlir/test/mlir-pdll/Parser/include_td.pdll +++ mlir/test/mlir-pdll/Parser/include_td.pdll @@ -32,21 +32,21 @@ // CHECK-NEXT: CppClass: ::mlir::IntegerType // CHECK-NEXT: } -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-AttrConstraintDecl +// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code()));> +// CHECK: `Inputs` +// CHECK: `-VariableDecl {{.*}} Name Type +// CHECK:
`Constraints` +// CHECK: `-TypeConstraintDecl {{.*}} + // CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> // CHECK: `Inputs` // CHECK: `-VariableDecl {{.*}} Name Type // CHECK: `Constraints` // CHECK: `-OpConstraintDecl // CHECK: `-OpNameDecl - -// CHECK: UserConstraintDecl {{.*}} Name ResultType> Code(self));> -// CHECK: `Inputs` -// CHECK: `-VariableDecl {{.*}} Name Type -// CHECK: `Constraints` -// CHECK: `-TypeConstraintDecl {{.*}} Index: mlir/tools/mlir-tblgen/OpFormatGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -2304,7 +2304,7 @@ // DeclareOpInterfaceMethods // and the like. // TODO: Add hasCppInterface check. - if (auto name = def.getValueAsOptionalString("cppClassName")) { + if (auto name = def.getValueAsOptionalString("cppInterfaceName")) { if (*name == "InferTypeOpInterface" && def.getValueAsString("cppNamespace") == "::mlir") canInferResultTypes = true;