diff --git a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td --- a/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td +++ b/mlir/include/mlir/Dialect/Arithmetic/IR/ArithmeticBase.td @@ -9,6 +9,7 @@ #ifndef ARITHMETIC_BASE #define ARITHMETIC_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Arithmetic_Dialect : Dialect { diff --git a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td --- a/mlir/include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td +++ b/mlir/include/mlir/Dialect/GPU/ParallelLoopMapperAttr.td @@ -15,6 +15,7 @@ #define PARALLEL_LOOP_MAPPER_ATTR include "mlir/Dialect/GPU/GPUBase.td" +include "mlir/IR/EnumAttr.td" def BlockX : I64EnumAttrCase<"BlockX", 0>; def BlockY : I64EnumAttrCase<"BlockY", 1>; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -14,6 +14,7 @@ #ifndef LLVMIR_OP_BASE #define LLVMIR_OP_BASE +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -13,6 +13,7 @@ #ifndef VECTOR_OPS #define VECTOR_OPS +include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/include/mlir/IR/EnumAttr.td b/mlir/include/mlir/IR/EnumAttr.td --- a/mlir/include/mlir/IR/EnumAttr.td +++ b/mlir/include/mlir/IR/EnumAttr.td @@ -11,6 +11,252 @@ include "mlir/IR/AttrTypeBase.td" +//===----------------------------------------------------------------------===// +// Enum attribute kinds + +// Additional information for an enum attribute case. +class EnumAttrCaseInfo { + // The C++ enumerant symbol. + string symbol = sym; + + // The C++ enumerant value. + // If less than zero, there will be no explicit discriminator values assigned + // to enumerators in the generated enum class. + int value = intVal; + + // The string representation of the enumerant. May be the same as symbol. + string str = strVal; +} + +// An enum attribute case stored with IntegerAttr, which has an integer value, +// its representation as a string and a C++ symbol name which may be different. +class IntEnumAttrCaseBase : + EnumAttrCaseInfo, + SignlessIntegerAttrBase { + let predicate = + CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() == " # intVal>; +} + +// Cases of integer enum attributes with a specific type. By default, the string +// representation is the same as the C++ symbol name. +class I32EnumAttrCase + : IntEnumAttrCaseBase; +class I64EnumAttrCase + : IntEnumAttrCaseBase; + +// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal +// number of a bit that is set. It is an integer value with bits set to match +// the case. +class BitEnumAttrCaseBase : + EnumAttrCaseInfo, + SignlessIntegerAttrBase; + +// A bit enum case stored with a 32-bit IntegerAttr. `val` here is *not* the +// ordinal number of a bit that is set. It is a 32-bit integer value with bits +// set to match the case. +class I32BitEnumAttrCase + : BitEnumAttrCaseBase; + +// A bit enum case stored with a 64-bit IntegerAttr. `val` here is *not* the +// ordinal number of a bit that is set. It is a 64-bit integer value with bits +// bits set to match the case. +class I64BitEnumAttrCase + : BitEnumAttrCaseBase; + +// The special bit enum case for I32 with no bits set (i.e. value = 0). +class I32BitEnumAttrCaseNone + : I32BitEnumAttrCase; + +// The special bit enum case for I64 with no bits set (i.e. value = 0). +class I64BitEnumAttrCaseNone + : I64BitEnumAttrCase; + +// A bit enum case for a single bit, specified by a bit position. +// The pos argument refers to the index of the bit, and is limited +// to be in the range [0, bitwidth). +class BitEnumAttrCaseBit + : BitEnumAttrCaseBase { + assert !and(!ge(pos, 0), !lt(pos, intType.bitwidth)), + "bit position larger than underlying storage"; +} + +// A bit enum case for a single bit in a 32-bit enum, specified by the +// bit position. +class I32BitEnumAttrCaseBit + : BitEnumAttrCaseBit; + +// A bit enum case for a single bit in a 64-bit enum, specified by the +// bit position. +class I64BitEnumAttrCaseBit + : BitEnumAttrCaseBit; + + +// A bit enum case for a group/list of previously declared cases, providing +// a convenient alias for that group. +class BitEnumAttrCaseGroup cases, string str = sym> + : BitEnumAttrCaseBase; + +// A 32-bit enum case for a group/list of previously declared cases, providing +// a convenient alias for that group. +class I32BitEnumAttrCaseGroup cases, + string str = sym> + : BitEnumAttrCaseGroup; + +// A 64-bit enum case for a group/list of previously declared cases, providing +// a convenient alias for that group. +class I64BitEnumAttrCaseGroup cases, + string str = sym> + : BitEnumAttrCaseGroup; + +// Additional information for an enum attribute. +class EnumAttrInfo< + string name, list cases, Attr baseClass> : + Attr { + // The C++ enum class name + string className = name; + + // List of all accepted cases + list enumerants = cases; + + // The following fields are only used by the EnumsGen backend to generate + // an enum class definition and conversion utility functions. + + // The underlying type for the C++ enum class. An empty string mean the + // underlying type is not explicitly specified. + string underlyingType = ""; + + // The name of the utility function that converts a value of the underlying + // type to the corresponding symbol. It will have the following signature: + // + // ```c++ + // llvm::Optional<> (); + // ``` + string underlyingToSymbolFnName = "symbolize" # name; + + // The name of the utility function that converts a string to the + // corresponding symbol. It will have the following signature: + // + // ```c++ + // llvm::Optional<> (llvm::StringRef); + // ``` + string stringToSymbolFnName = "symbolize" # name; + + // The name of the utility function that converts a symbol to the + // corresponding string. It will have the following signature: + // + // ```c++ + // (); + // ``` + string symbolToStringFnName = "stringify" # name; + string symbolToStringFnRetType = "::llvm::StringRef"; + + // The name of the utility function that returns the max enum value used + // within the enum class. It will have the following signature: + // + // ```c++ + // static constexpr unsigned (); + // ``` + string maxEnumValFnName = "getMaxEnumValFor" # name; + + // Generate specialized Attribute class + bit genSpecializedAttr = 1; + // The underlying Attribute class, which holds the enum value + Attr baseAttrClass = baseClass; + // The name of specialized Enum Attribute class + string specializedAttrClassName = name # Attr; + + // Override Attr class fields for specialized class + let predicate = !if(genSpecializedAttr, + CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">, + baseAttrClass.predicate); + let storageType = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName, + baseAttrClass.storageType); + let returnType = !if(genSpecializedAttr, + cppNamespace # "::" # className, + baseAttrClass.returnType); + let constBuilderCall = !if(genSpecializedAttr, + cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)", + baseAttrClass.constBuilderCall); + let valueType = baseAttrClass.valueType; +} + +// An enum attribute backed by IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer though: only the values of the allowed cases are +// permitted as the integer value. +class IntEnumAttrBase cases, string summary> : + SignlessIntegerAttrBase { + let predicate = And<[ + SignlessIntegerAttrBase.predicate, + Or]>; +} + +class IntEnumAttr cases> : + EnumAttrInfo>; + +class I32EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint32_t"; +} +class I64EnumAttr cases> : + IntEnumAttr { + let underlyingType = "uint64_t"; +} + +// A bit enum stored with an IntegerAttr. +// +// Op attributes of this kind are stored as IntegerAttr. Extra verification will +// be generated on the integer to make sure only allowed bits are set. Besides, +// helper methods are generated to parse a string separated with a specified +// delimiter to a symbol and vice versa. +class BitEnumAttrBase cases, + string summary> + : SignlessIntegerAttrBase { + let predicate = And<[ + SignlessIntegerAttrBase.predicate, + // Make sure we don't have unknown bit set. + CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~(" + # !interleave(!foreach(case, cases, case.value # "u"), "|") # + ")))"> + ]>; +} + +class BitEnumAttr cases> + : EnumAttrInfo> { + // Determine "valid" bits from enum cases for error checking + int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value)); + + // We need to return a string because we may concatenate symbols for multiple + // bits together. + let symbolToStringFnRetType = "std::string"; + + // The delimiter used to separate bit enum cases in strings. + string separator = "|"; +} + +class I32BitEnumAttr cases> + : BitEnumAttr { + let underlyingType = "uint32_t"; +} + +class I64BitEnumAttr cases> + : BitEnumAttr { + let underlyingType = "uint64_t"; +} + // A C++ enum as an attribute parameter. The parameter implements a parser and // printer for the enum by dispatching calls to `stringToSymbol` and // `symbolToString`. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1213,252 +1213,6 @@ let isOptional = 1; } -//===----------------------------------------------------------------------===// -// Enum attribute kinds - -// Additional information for an enum attribute case. -class EnumAttrCaseInfo { - // The C++ enumerant symbol. - string symbol = sym; - - // The C++ enumerant value. - // If less than zero, there will be no explicit discriminator values assigned - // to enumerators in the generated enum class. - int value = intVal; - - // The string representation of the enumerant. May be the same as symbol. - string str = strVal; -} - -// An enum attribute case stored with IntegerAttr, which has an integer value, -// its representation as a string and a C++ symbol name which may be different. -class IntEnumAttrCaseBase : - EnumAttrCaseInfo, - SignlessIntegerAttrBase { - let predicate = - CPred<"$_self.cast<::mlir::IntegerAttr>().getInt() == " # intVal>; -} - -// Cases of integer enum attributes with a specific type. By default, the string -// representation is the same as the C++ symbol name. -class I32EnumAttrCase - : IntEnumAttrCaseBase; -class I64EnumAttrCase - : IntEnumAttrCaseBase; - -// A bit enum case stored with an IntegerAttr. `val` here is *not* the ordinal -// number of a bit that is set. It is an integer value with bits set to match -// the case. -class BitEnumAttrCaseBase : - EnumAttrCaseInfo, - SignlessIntegerAttrBase; - -// A bit enum case stored with a 32-bit IntegerAttr. `val` here is *not* the -// ordinal number of a bit that is set. It is a 32-bit integer value with bits -// set to match the case. -class I32BitEnumAttrCase - : BitEnumAttrCaseBase; - -// A bit enum case stored with a 64-bit IntegerAttr. `val` here is *not* the -// ordinal number of a bit that is set. It is a 64-bit integer value with bits -// bits set to match the case. -class I64BitEnumAttrCase - : BitEnumAttrCaseBase; - -// The special bit enum case for I32 with no bits set (i.e. value = 0). -class I32BitEnumAttrCaseNone - : I32BitEnumAttrCase; - -// The special bit enum case for I64 with no bits set (i.e. value = 0). -class I64BitEnumAttrCaseNone - : I64BitEnumAttrCase; - -// A bit enum case for a single bit, specified by a bit position. -// The pos argument refers to the index of the bit, and is limited -// to be in the range [0, bitwidth). -class BitEnumAttrCaseBit - : BitEnumAttrCaseBase { - assert !and(!ge(pos, 0), !lt(pos, intType.bitwidth)), - "bit position larger than underlying storage"; -} - -// A bit enum case for a single bit in a 32-bit enum, specified by the -// bit position. -class I32BitEnumAttrCaseBit - : BitEnumAttrCaseBit; - -// A bit enum case for a single bit in a 64-bit enum, specified by the -// bit position. -class I64BitEnumAttrCaseBit - : BitEnumAttrCaseBit; - - -// A bit enum case for a group/list of previously declared cases, providing -// a convenient alias for that group. -class BitEnumAttrCaseGroup cases, string str = sym> - : BitEnumAttrCaseBase; - -// A 32-bit enum case for a group/list of previously declared cases, providing -// a convenient alias for that group. -class I32BitEnumAttrCaseGroup cases, - string str = sym> - : BitEnumAttrCaseGroup; - -// A 64-bit enum case for a group/list of previously declared cases, providing -// a convenient alias for that group. -class I64BitEnumAttrCaseGroup cases, - string str = sym> - : BitEnumAttrCaseGroup; - -// Additional information for an enum attribute. -class EnumAttrInfo< - string name, list cases, Attr baseClass> : - Attr { - // The C++ enum class name - string className = name; - - // List of all accepted cases - list enumerants = cases; - - // The following fields are only used by the EnumsGen backend to generate - // an enum class definition and conversion utility functions. - - // The underlying type for the C++ enum class. An empty string mean the - // underlying type is not explicitly specified. - string underlyingType = ""; - - // The name of the utility function that converts a value of the underlying - // type to the corresponding symbol. It will have the following signature: - // - // ```c++ - // llvm::Optional<> (); - // ``` - string underlyingToSymbolFnName = "symbolize" # name; - - // The name of the utility function that converts a string to the - // corresponding symbol. It will have the following signature: - // - // ```c++ - // llvm::Optional<> (llvm::StringRef); - // ``` - string stringToSymbolFnName = "symbolize" # name; - - // The name of the utility function that converts a symbol to the - // corresponding string. It will have the following signature: - // - // ```c++ - // (); - // ``` - string symbolToStringFnName = "stringify" # name; - string symbolToStringFnRetType = "::llvm::StringRef"; - - // The name of the utility function that returns the max enum value used - // within the enum class. It will have the following signature: - // - // ```c++ - // static constexpr unsigned (); - // ``` - string maxEnumValFnName = "getMaxEnumValFor" # name; - - // Generate specialized Attribute class - bit genSpecializedAttr = 1; - // The underlying Attribute class, which holds the enum value - Attr baseAttrClass = baseClass; - // The name of specialized Enum Attribute class - string specializedAttrClassName = name # Attr; - - // Override Attr class fields for specialized class - let predicate = !if(genSpecializedAttr, - CPred<"$_self.isa<" # cppNamespace # "::" # specializedAttrClassName # ">()">, - baseAttrClass.predicate); - let storageType = !if(genSpecializedAttr, - cppNamespace # "::" # specializedAttrClassName, - baseAttrClass.storageType); - let returnType = !if(genSpecializedAttr, - cppNamespace # "::" # className, - baseAttrClass.returnType); - let constBuilderCall = !if(genSpecializedAttr, - cppNamespace # "::" # specializedAttrClassName # "::get($_builder.getContext(), $0)", - baseAttrClass.constBuilderCall); - let valueType = baseAttrClass.valueType; -} - -// An enum attribute backed by IntegerAttr. -// -// Op attributes of this kind are stored as IntegerAttr. Extra verification will -// be generated on the integer though: only the values of the allowed cases are -// permitted as the integer value. -class IntEnumAttrBase cases, string summary> : - SignlessIntegerAttrBase { - let predicate = And<[ - SignlessIntegerAttrBase.predicate, - Or]>; -} - -class IntEnumAttr cases> : - EnumAttrInfo>; - -class I32EnumAttr cases> : - IntEnumAttr { - let underlyingType = "uint32_t"; -} -class I64EnumAttr cases> : - IntEnumAttr { - let underlyingType = "uint64_t"; -} - -// A bit enum stored with an IntegerAttr. -// -// Op attributes of this kind are stored as IntegerAttr. Extra verification will -// be generated on the integer to make sure only allowed bits are set. Besides, -// helper methods are generated to parse a string separated with a specified -// delimiter to a symbol and vice versa. -class BitEnumAttrBase cases, - string summary> - : SignlessIntegerAttrBase { - let predicate = And<[ - SignlessIntegerAttrBase.predicate, - // Make sure we don't have unknown bit set. - CPred<"!($_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & (~(" - # !interleave(!foreach(case, cases, case.value # "u"), "|") # - ")))"> - ]>; -} - -class BitEnumAttr cases> - : EnumAttrInfo> { - // Determine "valid" bits from enum cases for error checking - int validBits = !foldl(0, cases, value, bitcase, !or(value, bitcase.value)); - - // We need to return a string because we may concatenate symbols for multiple - // bits together. - let symbolToStringFnRetType = "std::string"; - - // The delimiter used to separate bit enum cases in strings. - string separator = "|"; -} - -class I32BitEnumAttr cases> - : BitEnumAttr { - let underlyingType = "uint32_t"; -} - -class I64BitEnumAttr cases> - : BitEnumAttr { - let underlyingType = "uint64_t"; -} - //===----------------------------------------------------------------------===// // Composite attribute kinds diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -3,6 +3,7 @@ // RUN: mlir-tblgen -print-records -I %S/../../include %s | FileCheck %s --check-prefix=RECORD include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def Test_Dialect : Dialect { diff --git a/mlir/unittests/TableGen/enums.td b/mlir/unittests/TableGen/enums.td --- a/mlir/unittests/TableGen/enums.td +++ b/mlir/unittests/TableGen/enums.td @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" def CaseA: I32EnumAttrCase<"CaseA", 0>;