diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -1441,8 +1441,9 @@ ### TypeParameter tablegen class This is used to further specify attributes about each of the types parameters. -It includes documentation (`summary` and `syntax`), the C++ type to use, and a -custom allocator to use in the storage constructor method. +It includes documentation (`summary` and `syntax`), the C++ type to use, a +custom allocator to use in the storage constructor method, and a custom +comparator to decide if two instances of the parameter type are equal. ```tablegen // DO NOT DO THIS! @@ -1472,6 +1473,11 @@ - `$_allocator` is the TypeStorageAllocator in which to allocate objects. - `$_dst` is the variable in which to place the allocated data. +The `comparator` code block has the following substitutions: + +- `$_lhs` is an instance of the parameter type. +- `$_rhs` is an instance of the parameter type. + MLIR includes several specialized classes for common situations: - `StringRefParameter` for StringRefs. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2673,6 +2673,8 @@ class AttrOrTypeParameter { // Custom memory allocation code for storage constructor. code allocator = ?; + // Custom comparator used to compare two instances for equality. + code comparator = ?; // The C++ type of this parameter. string cppType = type; // One-line human-readable description of the argument. @@ -2689,6 +2691,12 @@ let allocator = [{$_dst = $_allocator.copyInto($_self);}]; } +// For APFloats, which require comparison. +class APFloatParameter : + AttrOrTypeParameter<"::llvm::APFloat", desc> { + let comparator = "$_lhs.bitwiseIsEqual($_rhs)"; +} + // For standard ArrayRefs, which require allocation. 
class ArrayRefParameter : AttrOrTypeParameter<"::llvm::ArrayRef<" # arrayOf # ">", desc> { diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -183,6 +183,9 @@ // If specified, get the custom allocator code for this parameter. Optional getAllocator() const; + // If specified, get the custom comparator code for this parameter. + Optional getComparator() const; + // Get the C++ type of this parameter. StringRef getCppType() const; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -177,22 +177,18 @@ llvm::Init *parameterType = def->getArg(index); if (isa(parameterType)) return Optional(); + if (auto *param = dyn_cast(parameterType)) + return param->getDef()->getValueAsOptionalString("allocator"); + llvm::PrintFatalError("Parameters DAG arguments must be either strings or " + "defs which inherit from AttrOrTypeParameter\n"); +} - if (auto *param = dyn_cast(parameterType)) { - llvm::RecordVal *code = param->getDef()->getValue("allocator"); - if (!code) - return Optional(); - if (llvm::StringInit *ci = dyn_cast(code->getValue())) - return ci->getValue(); - if (isa(code->getValue())) - return Optional(); - - llvm::PrintFatalError( - param->getDef()->getLoc(), - "Record `" + def->getArgName(index)->getValue() + - "', field `printer' does not have a code initializer!"); - } - +Optional AttrOrTypeParameter::getComparator() const { + llvm::Init *parameterType = def->getArg(index); + if (isa(parameterType)) + return Optional(); + if (auto *param = dyn_cast(parameterType)) + return param->getDef()->getValueAsOptionalString("comparator"); llvm::PrintFatalError("Parameters DAG arguments must be either strings or " "defs which inherit from AttrOrTypeParameter\n"); } diff --git a/mlir/test/mlir-tblgen/attrdefs.td 
b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -53,7 +53,7 @@ ins "int":$widthOfSomething, "::mlir::test::SimpleTypeA": $exampleTdType, - "SomeCppStruct": $exampleCppType, + APFloatParameter<"">: $apFloat, ArrayRefParameter<"int", "Matrix dimensions">:$dims, AttributeSelfTypeParameter<"">:$inner ); @@ -61,8 +61,8 @@ let genVerifyDecl = 1; // DECL-LABEL: class CompoundAAttr : public ::mlir::Attribute -// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner); -// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims, ::mlir::Type inner); +// DECL: static CompoundAAttr getChecked(llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner); +// DECL: static ::mlir::LogicalResult verify(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef<int> dims, ::mlir::Type inner); // DECL: static constexpr ::llvm::StringLiteral getMnemonic() { // DECL: return ::llvm::StringLiteral("cmpnd_a"); // DECL: } @@ -71,7 +71,7 @@ // DECL: void print(::mlir::DialectAsmPrinter &printer) const; // DECL: int getWidthOfSomething() const; // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const; -// DECL: SomeCppStruct getExampleCppType() const; +// DECL: ::llvm::APFloat getApFloat() const; // Check that AttributeSelfTypeParameter is handled properly.
// DEF-LABEL: struct CompoundAAttrStorage @@ -79,11 +79,21 @@ // DEF-NEXT: : ::mlir::AttributeStorage(inner), // DEF: bool operator==(const KeyTy &key) const { -// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType()); +// DEF-NEXT: if (!(widthOfSomething == std::get<0>(key))) +// DEF-NEXT: return false; +// DEF-NEXT: if (!(exampleTdType == std::get<1>(key))) +// DEF-NEXT: return false; +// DEF-NEXT: if (!(apFloat.bitwiseIsEqual(std::get<2>(key)))) +// DEF-NEXT: return false; +// DEF-NEXT: if (!(dims == std::get<3>(key))) +// DEF-NEXT: return false; +// DEF-NEXT: if (!(getType() == std::get<4>(key))) +// DEF-NEXT: return false; +// DEF-NEXT: return true; // DEF: static CompoundAAttrStorage *construct // DEF: return new (allocator.allocate<CompoundAAttrStorage>()) -// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner); +// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, apFloat, dims, inner); // DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); } } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -432,22 +432,16 @@ /// {1}: Storage class c++ name. /// {2}: Parameters parameters. /// {3}: Parameter initializer string. -/// {4}: Parameter name list. -/// {5}: Parameter types. -/// {6}: The name of the base value type, e.g. Attribute or Type. +/// {4}: Parameter types. +/// {5}: The name of the base value type, e.g. Attribute or Type. static const char *const defStorageClassBeginStr = R"( namespace {0} {{ - struct {1} : public ::mlir::{6}Storage {{ + struct {1} : public ::mlir::{5}Storage {{ {1} ({2}) : {3} {{ } /// The hash key is a tuple of the parameter types. - using KeyTy = std::tuple<{5}>; - - /// Define the comparison function for the key type.
- bool operator==(const KeyTy &key) const {{ - return key == KeyTy({4}); - } + using KeyTy = std::tuple<{4}>; )"; /// The storage class' constructor template. @@ -555,23 +549,34 @@ }); } - // Construct the parameter list that is used when a concrete instance of the - // storage exists. - auto nonStaticParameterNames = llvm::map_range(params, [](const auto &param) { - return isa<AttributeSelfTypeParameter>(param) ? "getType()" - : param.getName(); - }); - - // 1) Emit most of the storage class up until the hashKey body. + // * Emit most of the storage class up until the hashKey body. os << formatv( defStorageClassBeginStr, def.getStorageNamespace(), def.getStorageClassName(), ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, params, /*prependComma=*/false), - paramInitializer, llvm::join(nonStaticParameterNames, ", "), - parameterTypeList, valueType); + paramInitializer, parameterTypeList, valueType); + + // * Emit the comparison method. + os << " bool operator==(const KeyTy &key) const {\n"; + for (auto it : llvm::enumerate(params)) { + os << " if (!("; + + // Build the comparator context. + bool isSelfType = isa<AttributeSelfTypeParameter>(it.value()); + FmtContext context; + context.addSubst("_lhs", isSelfType ? "getType()" : it.value().getName()) + .addSubst("_rhs", "std::get<" + Twine(it.index()) + ">(key)"); + + // Use the parameter specified comparator if possible, otherwise default to + // operator==. + Optional<StringRef> comparator = it.value().getComparator(); + os << tgfmt(comparator ? *comparator : "$_lhs == $_rhs", &context); + os << "))\n return false;\n"; + } + os << " return true;\n }\n"; - // 2) Emit the haskKey method. + // * Emit the hashKey method. os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; // Extract each parameter from the key. @@ -581,7 +586,7 @@ [&](unsigned it) { os << "std::get<" << it << ">(key)"; }); os << ");\n }\n"; - // 3) Emit the construct method. + // * Emit the construct method.
// If user wants to build the storage constructor themselves, declare it // here and then they can write the definition elsewhere. @@ -611,7 +616,7 @@ llvm::join(parameterNames, ", ")); } - // 4) Emit the parameters as storage class members. + // * Emit the parameters as storage class members. for (const AttrOrTypeParameter &parameter : params) { // Attribute value types are not stored as fields in the storage. if (!isa<AttributeSelfTypeParameter>(parameter))