diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2621,6 +2621,14 @@ // The name of the C++ Attribute class. string cppClassName = name # "Attr"; + // A code block used to build the value 'Type' of an Attribute when + // initializing its storage instance. This field is optional, and if not + // present the attribute will have its value type set to `NoneType`. This code + // block may reference any of the attribute's parameters via + // `$_()">; @@ -2686,4 +2694,10 @@ }]; } +// This is a special parameter used for AttrDefs that represents a `mlir::Type` +// that is also used as the value `Type` of the attribute. Only one parameter +// of the attribute may be of this type. +class AttributeSelfTypeParameter : + AttrOrTypeParameter<"::mlir::Type", desc> {} + #endif // OP_BASE diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -130,7 +130,10 @@ // Returns whether the AttrOrTypeDef is defined. operator bool() const { return def != nullptr; } -private: + // Return the underlying def. + const llvm::Record *getDef() const { return def; } + +protected: const llvm::Record *def; // The builders of this type definition. @@ -145,6 +148,12 @@ class AttrDef : public AttrOrTypeDef { public: using AttrOrTypeDef::AttrOrTypeDef; + + // Returns the attribute's value type builder code block, or None if it + // doesn't have one. + Optional getTypeBuilder() const; + + static bool classof(const AttrOrTypeDef *def); }; //===----------------------------------------------------------------------===// @@ -183,6 +192,9 @@ // Get the assembly syntax documentation. StringRef getSyntax() const; + // Return the underlying def of this parameter.
+ const llvm::Init *getDef() const; + private: /// The underlying tablegen parameter list this parameter is a part of. const llvm::DagInit *def; @@ -190,6 +202,17 @@ unsigned index; }; +//===----------------------------------------------------------------------===// +// AttributeSelfTypeParameter +//===----------------------------------------------------------------------===// + +// A wrapper class for the AttributeSelfTypeParameter tblgen class. This +// represents a parameter of mlir::Type that is the value type of an AttrDef. +class AttributeSelfTypeParameter : public AttrOrTypeParameter { +public: + static bool classof(const AttrOrTypeParameter *param); +}; + } // end namespace tblgen } // end namespace mlir diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -153,6 +153,18 @@ return getName() < other.getName(); } +//===----------------------------------------------------------------------===// +// AttrDef +//===----------------------------------------------------------------------===// + +Optional AttrDef::getTypeBuilder() const { + return def->getValueAsOptionalString("typeBuilder"); +} + +bool AttrDef::classof(const AttrOrTypeDef *def) { + return def->getDef()->isSubClassOf("AttrDef"); +} + //===----------------------------------------------------------------------===// // AttrOrTypeParameter //===----------------------------------------------------------------------===// @@ -219,3 +231,18 @@ llvm::PrintFatalError("Parameters DAG arguments must be either strings or " "defs which inherit from AttrOrTypeParameter"); } + +const llvm::Init *AttrOrTypeParameter::getDef() const { + return def->getArg(index); +} + +//===----------------------------------------------------------------------===// +// AttributeSelfTypeParameter +//===----------------------------------------------------------------------===// + +bool
AttributeSelfTypeParameter::classof(const AttrOrTypeParameter *param) { + const llvm::Init *paramDef = param->getDef(); + if (auto *paramDefInit = dyn_cast(paramDef)) + return paramDefInit->getDef()->isSubClassOf("AttributeSelfTypeParameter"); + return false; +} diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -41,4 +41,17 @@ ); } +// An attribute testing AttributeSelfTypeParameter. +def AttrWithSelfTypeParam : Test_Attr<"AttrWithSelfTypeParam"> { + let mnemonic = "attr_with_self_type_param"; + let parameters = (ins AttributeSelfTypeParameter<"">:$type); +} + +// An attribute testing the `typeBuilder` field. +def AttrWithTypeBuilder : Test_Attr<"AttrWithTypeBuilder"> { + let mnemonic = "attr_with_type_builder"; + let parameters = (ins "::mlir::IntegerAttr":$attr); + let typeBuilder = "$_attr.getType()"; +} + #endif // TEST_ATTRDEFS diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -23,6 +23,43 @@ using namespace mlir; using namespace mlir::test; +//===----------------------------------------------------------------------===// +// AttrWithSelfTypeParamAttr +//===----------------------------------------------------------------------===// + +Attribute AttrWithSelfTypeParamAttr::parse(MLIRContext *context, + DialectAsmParser &parser, + Type type) { + Type selfType; + if (parser.parseType(selfType)) + return Attribute(); + return get(context, selfType); +} + +void AttrWithSelfTypeParamAttr::print(DialectAsmPrinter &printer) const { + printer << "attr_with_self_type_param " << getType(); +} + +//===----------------------------------------------------------------------===// +// AttrWithTypeBuilderAttr
+//===----------------------------------------------------------------------===// + +Attribute AttrWithTypeBuilderAttr::parse(MLIRContext *context, + DialectAsmParser &parser, Type type) { + IntegerAttr element; + if (parser.parseAttribute(element)) + return Attribute(); + return get(context, element); +} + +void AttrWithTypeBuilderAttr::print(DialectAsmPrinter &printer) const { + printer << "attr_with_type_builder " << getAttr(); +} + +//===----------------------------------------------------------------------===// +// CompoundAAttr +//===----------------------------------------------------------------------===// + Attribute CompoundAAttr::parse(MLIRContext *context, DialectAsmParser &parser, Type type) { int widthOfSomething; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -419,6 +419,14 @@ let arguments = (ins AnyType:$x, AnyType:$y); } +def ResultHasSameTypeAsAttr : + TEST_Op<"result_has_same_type_as_attr", + [AllTypesMatch<["attr", "result"]>]> { + let arguments = (ins AnyAttr:$attr); + let results = (outs AnyType:$result); + let assemblyFormat = "$attr `->` type($result) attr-dict"; +} + def OperandZeroAndResultHaveSameType : TEST_Op<"operand0_and_result_have_same_type", [AllTypesMatch<["x", "res"]>]> { diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -25,7 +25,7 @@ // DEF: if (mnemonic == ::mlir::test::CompoundAAttr::getMnemonic()) return ::mlir::test::CompoundAAttr::parse(context, parser, type); // DEF-NEXT: if (mnemonic == ::mlir::test::IndexAttr::getMnemonic()) return ::mlir::test::IndexAttr::parse(context, parser, type); // DEF-NEXT: if (mnemonic == ::mlir::test::IntegerAttr::getMnemonic()) return ::mlir::test::IntegerAttr::parse(context, parser, type); -// DEF-NEXT: return ::mlir::Attribute(); +// DEF:
return ::mlir::Attribute(); def Test_Dialect: Dialect { // DECL-NOT: TestDialect @@ -51,7 +51,7 @@ "::mlir::test::SimpleTypeA": $exampleTdType, "SomeCppStruct": $exampleCppType, ArrayRefParameter<"int", "Matrix dimensions">:$dims, - "::mlir::Type":$inner + AttributeSelfTypeParameter<"">:$inner ); let genVerifyDecl = 1; @@ -66,6 +66,20 @@ // DECL: int getWidthOfSomething() const; // DECL: ::mlir::test::SimpleTypeA getExampleTdType() const; // DECL: SomeCppStruct getExampleCppType() const; + +// Check that AttributeSelfTypeParameter is handled properly. +// DEF-LABEL: struct CompoundAAttrStorage +// DEF: CompoundAAttrStorage ( +// DEF-NEXT: : ::mlir::AttributeStorage(inner), + +// DEF: bool operator==(const KeyTy &key) const { +// DEF-NEXT: return key == KeyTy(widthOfSomething, exampleTdType, exampleCppType, dims, getType()); + +// DEF: static CompoundAAttrStorage *construct +// DEF: return new (allocator.allocate()) +// DEF-NEXT: CompoundAAttrStorage(widthOfSomething, exampleTdType, exampleCppType, dims, inner); + +// DEF: ::mlir::Type CompoundAAttr::getInner() const { return getImpl()->getType(); } } def C_IndexAttr : TestAttr<"Index"> { @@ -134,3 +148,14 @@ // DECL-NEXT: /// Return true if this is an unsigned integer type. // DECL-NEXT: bool isUnsigned() const { return getSignedness() == Unsigned; } } + +// An attribute testing the `typeBuilder` field.
+def F_AttrWithTypeBuilder : TestAttr<"AttrWithTypeBuilder"> { + let mnemonic = "attr_with_type_builder"; + let parameters = (ins "::mlir::IntegerAttr":$attr); + let typeBuilder = "$_attr.getType()"; +} + +// DEF-LABEL: struct AttrWithTypeBuilderAttrStorage +// DEF: AttrWithTypeBuilderAttrStorage (::mlir::IntegerAttr attr) +// DEF-NEXT: : ::mlir::AttributeStorage(attr.getType()), attr(attr) diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -3,3 +3,9 @@ // CHECK-LABEL: func private @compoundA() // CHECK-SAME: #test.cmpnd_a<1, !test.smpla, [5, 6]> func private @compoundA() attributes {foo = #test.cmpnd_a<1, !test.smpla, [5, 6]>} + +// CHECK: test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32 +%a = test.result_has_same_type_as_attr #test<"attr_with_self_type_param i32"> -> i32 + +// CHECK: test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 +%b = test.result_has_same_type_as_attr #test<"attr_with_type_builder 10 : i16"> -> i16 diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -484,31 +484,86 @@ } } +static std::string buildAttributeStorageParamInitializer( + const AttrOrTypeDef &def, ArrayRef parameters) { + std::string paramInitializer; + llvm::raw_string_ostream paramOS(paramInitializer); + paramOS << "::mlir::AttributeStorage("; + + // The AttributeSelfTypeParameter, if any, provides the value type.
+ Optional selfParamIndex; + for (auto it : llvm::enumerate(parameters)) { + const auto *selfParam = dyn_cast(&it.value()); + if (!selfParam) + continue; + if (selfParamIndex) { + llvm::PrintFatalError(def.getLoc(), + "Only one attribute parameter can be marked as " + "AttributeSelfTypeParameter"); + } + paramOS << selfParam->getName(); + selfParamIndex = it.index(); + } + + // Without a self type parameter, fall back to the def's type builder (if + // any); the type builder is ignored when a self type parameter exists. + if (!selfParamIndex) { + const AttrDef &attrDef = cast(def); + if (Optional typeBuilder = attrDef.getTypeBuilder()) { + FmtContext fmtContext; + for (const AttrOrTypeParameter ¶m : parameters) + fmtContext.addSubst(("_" + param.getName()).str(), param.getName()); + paramOS << tgfmt(*typeBuilder, &fmtContext); + } + } + paramOS << ")"; + + // Append a `name(name)` member initializer for each non-self parameter. + for (auto it : llvm::enumerate(parameters)) + if (it.index() != selfParamIndex) + paramOS << llvm::formatv(", {0}({0})", it.value().getName()); + + return paramOS.str(); +} + void DefGenerator::emitStorageClass(const AttrOrTypeDef &def) { - SmallVector parameters; - def.getParameters(parameters); + SmallVector params; + def.getParameters(params); - // Collect the parameter names and types. - auto parameterNames = - map_range(parameters, [](AttrOrTypeParameter parameter) { - return parameter.getName(); - }); + // Collect the parameter types. auto parameterTypes = - map_range(parameters, [](AttrOrTypeParameter parameter) { + llvm::map_range(params, [](const AttrOrTypeParameter ¶meter) { return parameter.getCppType(); }); - auto parameterList = join(parameterNames, ", "); - auto parameterTypeList = join(parameterTypes, ", "); + std::string parameterTypeList = llvm::join(parameterTypes, ", "); + + // Build the storage constructor's member initializer list.
+ std::string paramInitializer; + if (isAttrGenerator) { + paramInitializer = buildAttributeStorageParamInitializer(def, params); + + } else { + llvm::raw_string_ostream initOS(paramInitializer); + llvm::interleaveComma(params, initOS, [&](const AttrOrTypeParameter &it) { + initOS << llvm::formatv("{0}({0})", it.getName()); + }); + } + + // Construct the parameter list that is used when a concrete instance of the + // storage exists. + auto nonStaticParameterNames = llvm::map_range(params, [](const auto ¶m) { + return isa(param) ? "getType()" + : param.getName(); + }); // 1) Emit most of the storage class up until the hashKey body. os << formatv( defStorageClassBeginStr, def.getStorageNamespace(), def.getStorageClassName(), ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNamePairs, - parameters, /*prependComma=*/false), - ParamCommaFormatter(ParamCommaFormatter::EmitFormat::TypeNameInitializer, - parameters, /*prependComma=*/false), - parameterList, parameterTypeList, valueType); + params, /*prependComma=*/false), + paramInitializer, llvm::join(nonStaticParameterNames, ", "), + parameterTypeList, valueType); // 2) Emit the haskKey method. os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n"; @@ -516,7 +571,7 @@ // Extract each parameter from the key. os << " return ::llvm::hash_combine("; llvm::interleaveComma( - llvm::seq(0, parameters.size()), os, + llvm::seq(0, params.size()), os, [&](unsigned it) { os << "std::get<" << it << ">(key)"; }); os << ");\n }\n"; @@ -534,9 +589,9 @@ // First, unbox the parameters.
os << formatv(defStorageClassConstructorBeginStr, def.getStorageClassName(), valueType); - for (unsigned i = 0, e = parameters.size(); i < e; ++i) { + for (unsigned i = 0, e = params.size(); i < e; ++i) { os << formatv(" auto {0} = std::get<{1}>(key);\n", - parameters[i].getName(), i); + params[i].getName(), i); } // Second, reassign the parameter variables with allocation code, if it's @@ -544,14 +599,18 @@ emitStorageParameterAllocation(def, os); // Last, return an allocated copy. + auto parameterNames = llvm::map_range( + params, [](const auto ¶m) { return param.getName(); }); os << formatv(defStorageClassConstructorEndStr, def.getStorageClassName(), - parameterList); + llvm::join(parameterNames, ", ")); } // 4) Emit the parameters as storage class members. - for (auto parameter : parameters) { - os << " " << parameter.getCppType() << " " << parameter.getName() - << ";\n"; + for (const AttrOrTypeParameter ¶meter : params) { + // Attribute value types are not stored as fields in the storage. + if (!isa(parameter)) + os << " " << parameter.getCppType() << " " << parameter.getName() + << ";\n"; } os << " };\n"; @@ -707,10 +766,14 @@ // Otherwise, let the user define the exact accessor definition. if (def.genAccessors() && def.genStorageClass()) { for (const AttrOrTypeParameter ¶meter : parameters) { + StringRef paramStorageName = isa(parameter) + ? "getType()" + : parameter.getName(); + + SmallString<16> name = parameter.getName(); name[0] = llvm::toUpper(name[0]); os << formatv("{0} {3}::get{1}() const {{ return getImpl()->{2}; }\n", - parameter.getCppType(), name, parameter.getName(), + parameter.getCppType(), name, paramStorageName, def.getCppClassName()); } }