diff --git a/mlir/docs/AttributesAndTypes.md b/mlir/docs/AttributesAndTypes.md --- a/mlir/docs/AttributesAndTypes.md +++ b/mlir/docs/AttributesAndTypes.md @@ -71,7 +71,7 @@ ```tablegen // Include the definition of the necessary tablegen constructs for defining -// our types. +// our types. include "mlir/IR/AttrTypeBase.td" // It's common to define a base classes for types in the same dialect. This @@ -108,7 +108,7 @@ ```tablegen // Include the definition of the necessary tablegen constructs for defining -// our attributes. +// our attributes. include "mlir/IR/AttrTypeBase.td" // It's common to define a base classes for attributes in the same dialect. This @@ -128,11 +128,11 @@ }]; /// Here we've defined two parameters, one is the `self` type of the attribute /// (i.e. the type of the Attribute itself), and the other is the integer value - /// of the attribute. + /// of the attribute. let parameters = (ins AttributeSelfTypeParameter<"">:$type, "APInt":$value); - + /// Here we've defined a custom builder for the type, that removes the need to pass - /// in an MLIRContext instance; as it can be infered from the `type`. + /// in an MLIRContext instance, as it can be inferred from the `type`. let builders = [ AttrBuilderWithInferredContext<(ins "Type":$type, "const APInt &":$value), [{ @@ -147,7 +147,7 @@ /// #my.int<50> : !my.int<32> // a 32-bit integer of value 50. /// let assemblyFormat = "`<` $value `>`"; - + /// Indicate that our attribute will add additional verification to the parameters. let genVerifyDecl = 1; @@ -612,6 +612,10 @@ `ArrayRefParameter` uses `SmallVector` as its storage type. The parsers for these parameters are expected to return `FailureOr<$cppStorageType>`. +To customize how a parsed `cppStorageType` value is converted to the C++ type of +the parameter, parameters can override `convertFromStorage`, a C++ expression in +which `$_self` is the storage value; the default `"$_self"` converts implicitly. 
+ ###### Optional Parameters Optional parameters in the assembly format can be indicated by setting @@ -1060,7 +1064,7 @@ #define GET_ATTRDEF_LIST #include "MyDialect/Attributes.cpp.inc" >(); - + /// Add the defined types to the dialect. addTypes< #define GET_TYPEDEF_LIST diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -295,6 +295,8 @@ // The C++ storage type of of this parameter if it is a reference, e.g. // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`. string cppStorageType = ?; + // C++ code converting `$_self`, a `cppStorageType` value, into `cppType`. + string convertFromStorage = "$_self"; // One-line human-readable description of the argument. string summary = desc; // The format string for the asm syntax (documentation only). diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -76,6 +76,9 @@ /// Get the C++ storage type of this parameter. StringRef getCppStorageType() const; + /// Get the C++ code converting `$_self` from the storage type to `cppType`. + StringRef getConvertFromStorage() const; + /// Get an optional C++ parameter parser.
Optional<StringRef> getParser() const; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -229,14 +229,15 @@ } StringRef AttrOrTypeParameter::getCppType() const { - llvm::Init *parameterType = getDef(); - if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType)) + if (auto *stringType = dyn_cast<llvm::StringInit>(getDef())) return stringType->getValue(); - if (auto *param = dyn_cast<llvm::DefInit>(parameterType)) - return param->getDef()->getValueAsString("cppType"); + // Keep the tablegen diagnostic for malformed parameters rather than + // dereferencing an empty Optional. + if (Optional<StringRef> cppType = getDefValue<llvm::StringInit>("cppType")) + return *cppType; llvm::PrintFatalError( "Parameters DAG arguments must be either strings or defs " "which inherit from AttrOrTypeParameter\n"); } StringRef AttrOrTypeParameter::getCppAccessorType() const { @@ -248,6 +249,11 @@ return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType()); } +StringRef AttrOrTypeParameter::getConvertFromStorage() const { + return getDefValue<llvm::StringInit>("convertFromStorage") + .value_or("$_self"); +} + Optional<StringRef> AttrOrTypeParameter::getParser() const { return getDefValue<llvm::StringInit>("parser"); } diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -55,8 +55,8 @@ // ATTR: if (odsParser.parseRParen()) // ATTR: return {}; // ATTR: return TestAAttr::get(odsParser.getContext(), -// ATTR: *_result_value, -// ATTR: *_result_complex); +// ATTR: (*_result_value), +// ATTR: (*_result_complex)); // ATTR: } // ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const { @@ -114,8 +114,8 @@ // ATTR: return {}; // ATTR: } // ATTR: return TestBAttr::get(odsParser.getContext(), -// ATTR: *_result_v0, -// ATTR: *_result_v1); +// ATTR: (*_result_v0), +// ATTR: (*_result_v1)); // ATTR: } // ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const { @@ -151,8 +151,8 @@ // ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; // ATTR: return
TestFAttr::get(odsParser.getContext(), -// ATTR: *_result_v0, -// ATTR: *_result_v1); +// ATTR: (*_result_v0), +// ATTR: (*_result_v1)); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -278,10 +278,10 @@ // TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; // TYPE: return TestDType::get(odsParser.getContext(), -// TYPE: *_result_v0, -// TYPE: *_result_v1, -// TYPE: *_result_v2, -// TYPE: *_result_v3); +// TYPE: (*_result_v0), +// TYPE: (*_result_v1), +// TYPE: (*_result_v2), +// TYPE: (*_result_v3)); // TYPE: } // TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const { @@ -369,10 +369,10 @@ // TYPE: return {}; // TYPE: } // TYPE: return TestEType::get(odsParser.getContext(), -// TYPE: *_result_v0, -// TYPE: *_result_v1, -// TYPE: *_result_v2, -// TYPE: *_result_v3); +// TYPE: (*_result_v0), +// TYPE: (*_result_v1), +// TYPE: (*_result_v2), +// TYPE: (*_result_v3)); // TYPE: } // TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const { @@ -535,3 +535,19 @@ let mnemonic = "type_j"; let assemblyFormat = "custom($a) custom($b, ref($a))"; } + +// TYPE: ::mlir::Type TestMType::parse +// TYPE: FailureOr<float> _result_a +// TYPE: return TestMType::get +// TYPE: static_cast<int>((*_result_a)) + +def ConvertFromStorageParameter : TypeParameter<"int", ""> { + let cppStorageType = "float"; + let convertFromStorage = "static_cast<int>($_self)"; +} + +def TypeK : TestType<"TestM"> { + let parameters = (ins ConvertFromStorageParameter:$a); + let mnemonic = "type_k"; + let assemblyFormat = "$a"; +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -297,18 +297,22 @@ } for (const AttrOrTypeParameter &param : params) { os << ",\n "; + // Build the parsed storage-value expression, then apply convertFromStorage. + std::string paramSelfStr; + llvm::raw_string_ostream selfOs(paramSelfStr); if (param.isOptional()) { - os << formatv("_result_{0}.value_or(", param.getName()); + selfOs << 
formatv("(_result_{0}.value_or(", param.getName()); if (Optional<StringRef> defaultValue = param.getDefaultValue()) - os << tgfmt(*defaultValue, &ctx); + selfOs << tgfmt(*defaultValue, &ctx); else - os << param.getCppStorageType() << "()"; - os << ")"; + selfOs << param.getCppStorageType() << "()"; + selfOs << "))"; } else if (isa<AttributeSelfTypeParameter>(param)) { - os << tgfmt("$_type", &ctx); + selfOs << tgfmt("$_type", &ctx); } else { - os << formatv("*_result_{0}", param.getName()); + selfOs << formatv("(*_result_{0})", param.getName()); } + os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str())); } os << ");"; }