diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td
--- a/mlir/include/mlir/IR/AttrTypeBase.td
+++ b/mlir/include/mlir/IR/AttrTypeBase.td
@@ -295,6 +295,9 @@
   // The C++ storage type of of this parameter if it is a reference, e.g.
   // `std::string` for `StringRef` or `SmallVector` for `ArrayRef`.
   string cppStorageType = ?;
+  // The C++ code to convert from the storage type to the parameter type. The
+  // code may refer to the storage value as `$_self`.
+  string convertFromStorage = "$_self";
   // One-line human-readable description of the argument.
   string summary = desc;
   // The format string for the asm syntax (documentation only).
diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
--- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h
+++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h
@@ -76,6 +76,10 @@
   /// Get the C++ storage type of this parameter.
   StringRef getCppStorageType() const;
 
+  /// Get the C++ code to convert from the storage type to the parameter type.
+  /// The code may refer to the storage value as `$_self`.
+  StringRef getConvertFromStorage() const;
+
   /// Get an optional C++ parameter parser.
   Optional<StringRef> getParser() const;
 
diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp
--- a/mlir/lib/TableGen/AttrOrTypeDef.cpp
+++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp
@@ -229,14 +229,15 @@
 }
 
 StringRef AttrOrTypeParameter::getCppType() const {
-  llvm::Init *parameterType = getDef();
-  if (auto *stringType = dyn_cast<llvm::StringInit>(parameterType))
+  if (auto *stringType = dyn_cast<llvm::StringInit>(getDef()))
     return stringType->getValue();
-  if (auto *param = dyn_cast<llvm::DefInit>(parameterType))
-    return param->getDef()->getValueAsString("cppType");
+  // Defs inheriting from AttrOrTypeParameter carry their C++ type in the
+  // `cppType` field; anything else gets the diagnostic below.
+  if (Optional<StringRef> cppType = getDefValue<llvm::StringInit>("cppType"))
+    return *cppType;
   llvm::PrintFatalError(
       "Parameters DAG arguments must be either strings or defs "
       "which inherit from AttrOrTypeParameter\n");
 }
 
 StringRef AttrOrTypeParameter::getCppAccessorType() const {
@@ -248,6 +249,10 @@
   return getDefValue<llvm::StringInit>("cppStorageType").value_or(getCppType());
 }
 
+StringRef AttrOrTypeParameter::getConvertFromStorage() const {
+  return getDefValue<llvm::StringInit>("convertFromStorage").value_or("$_self");
+}
+
 Optional<StringRef> AttrOrTypeParameter::getParser() const {
   return getDefValue<llvm::StringInit>("parser");
 }
diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td
--- a/mlir/test/mlir-tblgen/attr-or-type-format.td
+++ b/mlir/test/mlir-tblgen/attr-or-type-format.td
@@ -55,8 +55,8 @@
 // ATTR: if (odsParser.parseRParen())
 // ATTR: return {};
 // ATTR: return TestAAttr::get(odsParser.getContext(),
-// ATTR: *_result_value,
-// ATTR: *_result_complex);
+// ATTR: (*_result_value),
+// ATTR: (*_result_complex));
 // ATTR: }
 
 // ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -114,8 +114,8 @@
 // ATTR: return {};
 // ATTR: }
 // ATTR: return TestBAttr::get(odsParser.getContext(),
-// ATTR: *_result_v0,
-// ATTR: *_result_v1);
+// ATTR: (*_result_v0),
+// ATTR: (*_result_v1));
 // ATTR: }
 
 // ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -151,8 +151,8 @@
 // ATTR: if (::mlir::failed(_result_v1))
 // ATTR: return {};
 // ATTR: return TestFAttr::get(odsParser.getContext(),
-// ATTR: *_result_v0,
-// ATTR: *_result_v1);
+// ATTR: (*_result_v0),
+// ATTR: (*_result_v1));
 // ATTR: }
 
 def AttrC : TestAttr<"TestF"> {
@@ -278,10 +278,10 @@
 // TYPE: if (::mlir::failed(_result_v3))
 // TYPE: return {};
 // TYPE: return TestDType::get(odsParser.getContext(),
-// TYPE: *_result_v0,
-// TYPE: *_result_v1,
-// TYPE: *_result_v2,
-// TYPE: *_result_v3);
+// TYPE: (*_result_v0),
+// TYPE: (*_result_v1),
+// TYPE: (*_result_v2),
+// TYPE: (*_result_v3));
 // TYPE: }
 
 // TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -369,10 +369,10 @@
 // TYPE: return {};
 // TYPE: }
 // TYPE: return TestEType::get(odsParser.getContext(),
-// TYPE: *_result_v0,
-// TYPE: *_result_v1,
-// TYPE: *_result_v2,
-// TYPE: *_result_v3);
+// TYPE: (*_result_v0),
+// TYPE: (*_result_v1),
+// TYPE: (*_result_v2),
+// TYPE: (*_result_v3));
 // TYPE: }
 
 // TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const {
@@ -537,1 +537,17 @@
 }
+
+// TYPE: ::mlir::Type TestMType::parse
+// TYPE: FailureOr<float> _result_a
+// TYPE: return TestMType::get
+// TYPE: static_cast<int>((*_result_a))
+
+def ConvertFromStorageParameter : TypeParameter<"int", ""> {
+  let cppStorageType = "float";
+  let convertFromStorage = "static_cast<int>($_self)";
+}
+
+def TypeK : TestType<"TestM"> {
+  let parameters = (ins ConvertFromStorageParameter:$a);
+  let mnemonic = "type_k";
+  let assemblyFormat = "$a";
+}
diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
--- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp
@@ -297,18 +297,28 @@
   }
   for (const AttrOrTypeParameter &param : params) {
     os << ",\n    ";
+    // Build the expression that yields the parsed value in its storage type.
+    // It is parenthesized so that it stays a single operand when substituted
+    // for `$_self` in the parameter's `convertFromStorage` code.
+    std::string paramSelfStr;
+    llvm::raw_string_ostream selfOs(paramSelfStr);
     if (param.isOptional()) {
-      os << formatv("_result_{0}.value_or(", param.getName());
+      selfOs << formatv("(_result_{0}.value_or(", param.getName());
       if (Optional<StringRef> defaultValue = param.getDefaultValue())
-        os << tgfmt(*defaultValue, &ctx);
+        selfOs << tgfmt(*defaultValue, &ctx);
       else
-        os << param.getCppStorageType() << "()";
-      os << ")";
+        selfOs << param.getCppStorageType() << "()";
+      selfOs << "))";
     } else if (isa<AttributeSelfTypeParameter>(param)) {
-      os << tgfmt("$_type", &ctx);
+      selfOs << tgfmt("$_type", &ctx);
     } else {
-      os << formatv("*_result_{0}", param.getName());
+      selfOs << formatv("(*_result_{0})", param.getName());
     }
+    // Bind `$_self` in a per-parameter copy of the context so the substitution
+    // does not leak into the shared `ctx` used for later parameters.
+    FmtContext selfCtx = ctx;
+    selfCtx.withSelf(selfOs.str());
+    os << tgfmt(param.getConvertFromStorage(), &selfCtx);
   }
   os << ");";
 }