Use `$_builder` placeholders in TableGen type builder calls.

Summary of this patch (three files):

* mlir/include/mlir/IR/OpBase.td
  - `Type` gains a `builderCall` string field (default ""), and
    `TypeAlias` forwards it from the aliased type.
  - The builder strings passed to `BuildableType` for Index, I<width>,
    F<width> and BF16 now name the builder explicitly with a
    `$_builder.` prefix instead of relying on being appended to one; the
    "Format: this will be affixed to the builder." comment is dropped.
  - `TypedAttrBase::constBuilderCall` splices in
    `attrValType.builderCall` directly rather than prefixing it with
    "$_builder.".
  - The `AnyUnrankedMemRef` hunk shows identical visible text on the
    removed and added lines, so it appears to be a whitespace-only change.

* mlir/lib/TableGen/Type.cpp
  - The builder-call lookup no longer requires the record to subclass
    `BuildableType`. It reads the `builderCall` field of the record
    (the `baseType` record for variadic constraints) and dispatches on
    the field's value with `TypeSwitch`. The exact `Case` types are not
    visible in this copy. The lookup returns None when the field is
    absent, has no value, is empty, or is of an unhandled kind.
  - Adds the `mlir/ADT/TypeSwitch.h` include for this.

* mlir/tools/mlir-tblgen/OpFormatGen.cpp
  - `genParserTypeResolution` now emits `Builder &builder =
    parser.getBuilder();` once (only when there are buildable types).
  - It expands each builder call with `tgfmt` using a FmtContext whose
    builder is that local. Previously it emitted "parser.getBuilder()."
    followed by the raw call text.

Migration note: nothing prepends the builder any more. Any out-of-tree
`BuildableType<...>` builder string written as a bare method call
(e.g. "getIndexType()") must now carry the `$_builder.` prefix.

NOTE(review): this copy of the patch has its line breaks collapsed. The
contents of angle-bracket argument lists also appear to have been
stripped (e.g. `class Type : TypeConstraint {`, `TypeSwitch>(`,
`.Case([&]`, `Optional()`). It will therefore not apply as-is;
regenerate it from the source commit before use.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -259,12 +259,14 @@ class Type : TypeConstraint { string typeDescription = ""; + string builderCall = ""; } // Allows providing an alternative name and description to an existing type def. class TypeAlias : Type { let typeDescription = t.typeDescription; + let builderCall = t.builderCall; } // A type of a specific dialect. @@ -289,7 +291,6 @@ // making some Types and some Attrs buildable. class BuildableType { // The builder call to invoke (if specified) to construct the BuildableType. - // Format: this will be affixed to the builder. code builderCall = builder; } @@ -313,13 +314,13 @@ // Index type. def Index : Type()">, "index">, - BuildableType<"getIndexType()">; + BuildableType<"$_builder.getIndexType()">; // Integer type of a specific width. class I : Type, width # "-bit integer">, - BuildableType<"getIntegerType(" # width # ")"> { + BuildableType<"$_builder.getIntegerType(" # width # ")"> { int bitwidth = width; } @@ -342,7 +343,7 @@ class F : Type, width # "-bit float">, - BuildableType<"getF" # width # "Type()"> { + BuildableType<"$_builder.getF" # width # "Type()"> { int bitwidth = width; } @@ -355,7 +356,7 @@ def F64 : F<64>; def BF16 : Type, "bfloat16 type">, - BuildableType<"getBF16Type()">; + BuildableType<"$_builder.getBF16Type()">; class Complex : Type allowedTypes> : TensorRankOf; // Unranked Memref type -def AnyUnrankedMemRef : - ShapedContainerType<[AnyType], +def AnyUnrankedMemRef : + ShapedContainerType<[AnyType], IsUnrankedMemRefTypePred, "unranked.memref">; // Memref type. @@ -685,7 +686,7 @@ class TypedAttrBase : Attr { - let constBuilderCall = "$_builder.get" # attrKind # "($_builder." 
# + let constBuilderCall = "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)"; let storageType = attrKind; } diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Type.h" +#include "mlir/ADT/TypeSwitch.h" #include "llvm/TableGen/Record.h" using namespace mlir; @@ -36,9 +37,16 @@ if (isVariadic()) baseType = baseType->getValueAsDef("baseType"); - if (!baseType->isSubClassOf("BuildableType")) - return None; - return baseType->getValueAsString("builderCall"); + // Check to see if this type constraint has a builder call. + const llvm::RecordVal *builderCall = baseType->getValue("builderCall"); + if (!builderCall || !builderCall->getValue()) + return llvm::None; + return TypeSwitch>(builderCall->getValue()) + .Case([&](auto *init) { + StringRef value = init->getValue(); + return value.empty() ? Optional() : value; + }) + .Default([](auto *) { return llvm::None; }); } Type::Type(const llvm::Record *record) : TypeConstraint(record) {} diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -410,9 +410,15 @@ void OperationFormat::genParserTypeResolution(Operator &op, OpMethodBody &body) { // Initialize the set of buildable types. - for (auto &it : buildableTypes) - body << " Type odsBuildableType" << it.second << " = parser.getBuilder()." - << it.first << ";\n"; + if (!buildableTypes.empty()) { + body << " Builder &builder = parser.getBuilder();\n"; + + FmtContext typeBuilderCtx; + typeBuilderCtx.withBuilder("builder"); + for (auto &it : buildableTypes) + body << " Type odsBuildableType" << it.second << " = " + << tgfmt(it.first, &typeBuilderCtx) << ";\n"; + } // Resolve each of the result types. if (allResultTypes) {