diff --git a/mlir/include/mlir/IR/AttrTypeBase.td b/mlir/include/mlir/IR/AttrTypeBase.td --- a/mlir/include/mlir/IR/AttrTypeBase.td +++ b/mlir/include/mlir/IR/AttrTypeBase.td @@ -96,30 +96,38 @@ // This is necessary because the `body` is also used to generate `getChecked` // methods, which have a different underlying `Base::get*` call. // -class AttrOrTypeBuilder { +class AttrOrTypeBuilder { dag dagParams = parameters; code body = bodyCode; + // Change the return type of the builder. By default, it is the type of the + // attribute or type. + string returnType = returnTypeStr; + // The context parameter can be inferred from one of the other parameters and // is not implicitly added to the parameter list. bit hasInferredContextParam = 0; } -class AttrBuilder - : AttrOrTypeBuilder; -class TypeBuilder - : AttrOrTypeBuilder; +class AttrBuilder + : AttrOrTypeBuilder; +class TypeBuilder + : AttrOrTypeBuilder; // A class of AttrOrTypeBuilder that is able to infer the MLIRContext parameter // from one of the other builder parameters. Instances of this builder do not // have `MLIRContext *` implicitly added to the parameter list. 
-class AttrOrTypeBuilderWithInferredContext - : TypeBuilder { +class AttrOrTypeBuilderWithInferredContext + : TypeBuilder { let hasInferredContextParam = 1; } -class AttrBuilderWithInferredContext - : AttrOrTypeBuilderWithInferredContext; -class TypeBuilderWithInferredContext - : AttrOrTypeBuilderWithInferredContext; +class AttrBuilderWithInferredContext + : AttrOrTypeBuilderWithInferredContext; +class TypeBuilderWithInferredContext + : AttrOrTypeBuilderWithInferredContext; //===----------------------------------------------------------------------===// // Definitions diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -151,10 +151,9 @@ std::enable_if_t::value> *sfinae = nullptr> void printStrippedAttrOrType(ArrayRef attrOrTypes) { - llvm::interleaveComma(attrOrTypes, getStream(), - [this](AttrOrType attrOrType) { - printStrippedAttrOrType(attrOrType); - }); + llvm::interleaveComma( + attrOrTypes, getStream(), + [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); } /// SFINAE for printing the provided attribute in the context of an operation @@ -793,14 +792,14 @@ /// unlike `OpBuilder::getType`, this method does not implicitly insert a /// context parameter. template - T getChecked(SMLoc loc, ParamsT &&...params) { + auto getChecked(SMLoc loc, ParamsT &&...params) { return T::getChecked([&] { return emitError(loc); }, std::forward(params)...); } /// A variant of `getChecked` that uses the result of `getNameLoc` to emit /// errors. 
template - T getChecked(ParamsT &&...params) { + auto getChecked(ParamsT &&...params) { return T::getChecked([&] { return emitError(getNameLoc()); }, std::forward(params)...); } diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -37,6 +37,9 @@ public: using Builder::Builder; + /// Returns an optional builder return type. + Optional getReturnType() const; + /// Returns true if this builder is able to infer the MLIRContext parameter. bool hasInferredContextParameter() const; }; diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -20,7 +20,11 @@ // AttrOrTypeBuilder //===----------------------------------------------------------------------===// -/// Returns true if this builder is able to infer the MLIRContext parameter. +Optional AttrOrTypeBuilder::getReturnType() const { + Optional type = def->getValueAsOptionalString("returnType"); + return type && !type->empty() ? type : llvm::None; +} + bool AttrOrTypeBuilder::hasInferredContextParameter() const { return def->getValueAsBit("hasInferredContextParam"); } @@ -81,14 +85,6 @@ "'assemblyFormat' or 'hasCustomAssemblyFormat' can only be " "used when 'mnemonic' is set"); } - // Assembly format parser requires builders with the same prototype - // as the default-builders. - // TODO: attempt to detect when a custom builder matches the prototype. - if (hasDeclarativeFormat && skipDefaultBuilders()) { - PrintWarning(getLoc(), - "using 'assemblyFormat' with 'skipDefaultBuilders=1' may " - "result in C++ compilation errors"); - } // Assembly format printer requires accessors to be generated. 
    if (hasDeclarativeFormat && !genAccessors()) { PrintFatalError(getLoc(), diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -223,6 +223,19 @@ let assemblyFormat = "`<` $a `>`"; } +// Test overriding attribute builders with a custom builder. +def TestOverrideBuilderAttr : Test_Attr<"TestOverrideBuilder"> { + let mnemonic = "override_builder"; + let parameters = (ins "int":$a); + let assemblyFormat = "`<` $a `>`"; + + let skipDefaultBuilders = 1; + let genVerifyDecl = 1; + let builders = [AttrBuilder<(ins "int":$a), [{ + return ::mlir::IntegerAttr::get(::mlir::IndexType::get($_ctxt), a); + }], "::mlir::Attribute">]; +} + // Test simple extern 1D vector using ElementsAttrInterface. def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [ ElementsAttrInterface diff --git a/mlir/test/mlir-tblgen/attr-or-type-format.td b/mlir/test/mlir-tblgen/attr-or-type-format.td --- a/mlir/test/mlir-tblgen/attr-or-type-format.td +++ b/mlir/test/mlir-tblgen/attr-or-type-format.td @@ -55,8 +55,8 @@ // ATTR: if (odsParser.parseRParen()) // ATTR: return {}; // ATTR: return TestAAttr::get(odsParser.getContext(), -// ATTR: (*_result_value), -// ATTR: (*_result_complex)); +// ATTR: IntegerAttr((*_result_value)), +// ATTR: TestParamA((*_result_complex))); // ATTR: } // ATTR: void TestAAttr::print(::mlir::AsmPrinter &odsPrinter) const { @@ -114,8 +114,8 @@ // ATTR: return {}; // ATTR: } // ATTR: return TestBAttr::get(odsParser.getContext(), -// ATTR: (*_result_v0), -// ATTR: (*_result_v1)); +// ATTR: TestParamA((*_result_v0)), +// ATTR: TestParamB((*_result_v1))); // ATTR: } // ATTR: void TestBAttr::print(::mlir::AsmPrinter &odsPrinter) const { @@ -151,8 +151,8 @@ // ATTR: if (::mlir::failed(_result_v1)) // ATTR: return {}; // ATTR: return TestFAttr::get(odsParser.getContext(), -// ATTR: (*_result_v0), -// ATTR: (*_result_v1));
+// ATTR: int((*_result_v0)), +// ATTR: int((*_result_v1))); // ATTR: } def AttrC : TestAttr<"TestF"> { @@ -278,10 +278,10 @@ // TYPE: if (::mlir::failed(_result_v3)) // TYPE: return {}; // TYPE: return TestDType::get(odsParser.getContext(), -// TYPE: (*_result_v0), -// TYPE: (*_result_v1), -// TYPE: (*_result_v2), -// TYPE: (*_result_v3)); +// TYPE: TestParamC((*_result_v0)), +// TYPE: TestParamD((*_result_v1)), +// TYPE: TestParamC((*_result_v2)), +// TYPE: TestParamD((*_result_v3))); // TYPE: } // TYPE: void TestDType::print(::mlir::AsmPrinter &odsPrinter) const { @@ -369,10 +369,10 @@ // TYPE: return {}; // TYPE: } // TYPE: return TestEType::get(odsParser.getContext(), -// TYPE: (*_result_v0), -// TYPE: (*_result_v1), -// TYPE: (*_result_v2), -// TYPE: (*_result_v3)); +// TYPE: IntegerAttr((*_result_v0)), +// TYPE: IntegerAttr((*_result_v1)), +// TYPE: IntegerAttr((*_result_v2)), +// TYPE: IntegerAttr((*_result_v3))); // TYPE: } // TYPE: void TestEType::print(::mlir::AsmPrinter &odsPrinter) const { diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -31,9 +31,9 @@ // DEF-NEXT: .Case(::test::IndexAttr::getMnemonic() // DEF-NEXT: value = ::test::IndexAttr::parse(parser, type); // DEF-NEXT: return ::mlir::success(!!value); -// DEF: .Default([&](llvm::StringRef keyword, +// DEF: .Default([&](llvm::StringRef keyword, // DEF-NEXT: *mnemonic = keyword; -// DEF-NEXT: return llvm::None; +// DEF-NEXT: return llvm::None; def Test_Dialect: Dialect { // DECL-NOT: TestDialect @@ -148,3 +148,13 @@ // DEF: ParamWithAccessorTypeAttrStorage // DEF: ParamWithAccessorTypeAttrStorage(std::string param) // DEF: StringRef ParamWithAccessorTypeAttr::getParam() + +def G_BuilderWithReturnTypeAttr : TestAttr<"BuilderWithReturnType"> { + let parameters = (ins "int":$a); + let genVerifyDecl = 1; + let builders = [AttrBuilder<(ins), [{ return {}; }], "::mlir::Attribute">]; +} + 
+// DECL-LABEL: class BuilderWithReturnTypeAttr +// DECL: ::mlir::Attribute get( +// DECL: ::mlir::Attribute getChecked( diff --git a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir --- a/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir +++ b/mlir/test/mlir-tblgen/testdialect-attrdefs.mlir @@ -13,3 +13,9 @@ // CHECK-LABEL: @qualifiedAttr() // CHECK-SAME: #test.cmpnd_nested_outer_qual>> func.func private @qualifiedAttr() attributes {foo = #test.cmpnd_nested_outer_qual>>} + +// CHECK-LABEL: @overriddenAttr +// CHECK-SAME: foo = 5 : index +func.func private @overriddenAttr() attributes { + foo = #test.override_builder<5> +} diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -348,7 +348,10 @@ void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { // Don't emit a body if there isn't one. auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; - Method *m = defCls.addMethod(def.getCppClassName(), "get", props, + StringRef returnType = def.getCppClassName(); + if (Optional builderReturnType = builder.getReturnType()) + returnType = *builderReturnType; + Method *m = defCls.addMethod(returnType, "get", props, getCustomBuilderParams({}, builder)); if (!builder.getBody()) return; @@ -373,8 +376,11 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { // Don't emit a body if there isn't one. auto props = builder.getBody() ? 
Method::Static : Method::StaticDeclaration; + StringRef returnType = def.getCppClassName(); + if (Optional builderReturnType = builder.getReturnType()) + returnType = *builderReturnType; Method *m = defCls.addMethod( - def.getCppClassName(), "getChecked", props, + returnType, "getChecked", props, getCustomBuilderParams( {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, builder)); diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -311,7 +311,9 @@ } else { selfOs << formatv("(*_result_{0})", param.getName()); } - os << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str())); + os << param.getCppType() << "(" + << tgfmt(param.getConvertFromStorage(), &ctx.withSelf(selfOs.str())) + << ")"; } os << ");"; }