diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -412,7 +412,7 @@ // DEF: attributes.append(attrNames[1], odsBuilder.getDictionaryAttr(getDefaultDictAttrs(odsBuilder))); // DECL-LABEL: DefaultDictAttrOp declarations -// DECL: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::DictionaryAttr empty, ::mlir::DictionaryAttr non_empty) +// DECL: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::DictionaryAttr empty = nullptr, ::mlir::DictionaryAttr non_empty = nullptr) // Test derived type attr. diff --git a/mlir/test/mlir-tblgen/op-default-builder.td b/mlir/test/mlir-tblgen/op-default-builder.td new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/op-default-builder.td @@ -0,0 +1,71 @@ +// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s + +include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "foobar"; +} +class NS_Op traits> : + Op; + +def SomeAttr : Attr, "some attribute kind"> { + let storageType = "some-attr-kind"; + let returnType = "some-return-type"; + let convertFromStorage = "$_self.some-convert-from-storage()"; + let constBuilderCall = "some-const-builder-call($_builder, $0)"; +} + +def AOp : NS_Op<"a_op", []> { + let arguments = (ins + FloatLike:$lhs, + SomeAttr:$aAttr, + DefaultValuedAttr:$bAttr, + OptionalAttr:$cAttr, + DefaultValuedOptionalAttr:$dAttr + ); +} + +// CHECK-LABEL: AOp declarations +// Note: `cAttr` below could be conditionally optional and so the generation is +// currently conservative. 
+// CHECK-DAG: ::mlir::Value lhs, some-attr-kind aAttr, some-attr-kind bAttr, /*optional*/some-attr-kind cAttr, /*optional*/some-attr-kind dAttr); +// CHECK-DAG: ::mlir::Value lhs, some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr, /*optional*/some-return-type dAttr = 7.2); +// CHECK-DAG: ::mlir::TypeRange resultTypes, ::mlir::Value lhs, some-attr-kind aAttr, some-attr-kind bAttr, /*optional*/some-attr-kind cAttr, /*optional*/some-attr-kind dAttr); +// CHECK-DAG: ::mlir::TypeRange resultTypes, ::mlir::Value lhs, some-return-type aAttr, some-return-type bAttr, /*optional*/some-attr-kind cAttr, /*optional*/some-return-type dAttr = 7.2); + +def BOp : NS_Op<"b_op", []> { + let arguments = (ins + DefaultValuedAttr:$aAttr, + DefaultValuedAttr:$bAttr + ); +} + +// Verify that non-overlapping builders are created when all attributes could be elided. +// CHECK-LABEL: BOp declarations +// CHECK-DAG: some-attr-kind aAttr, some-attr-kind bAttr = nullptr); +// CHECK-DAG: some-return-type aAttr = 6.2, some-return-type bAttr = 4.2); +// CHECK-DAG: ::mlir::TypeRange resultTypes, some-attr-kind aAttr, some-attr-kind bAttr = nullptr); +// CHECK-DAG: ::mlir::TypeRange resultTypes, some-return-type aAttr = 6.2, some-return-type bAttr = 4.2); + +def COp : NS_Op<"c_op", []> { + let arguments = (ins + FloatLike:$value, + OptionalAttr:$ag, + OptionalAttr:$as, + OptionalAttr:$nos, + OptionalAttr:$al, + UnitAttr:$vo, + UnitAttr:$non + ); +} + +// CHECK-LABEL: COp declarations +// Note: `al` below could be conditionally optional and so the generation is +// currently conservative. 
+// CHECK-DAG: ::mlir::Value value, /*optional*/::mlir::ArrayAttr ag, /*optional*/::mlir::ArrayAttr as, /*optional*/::mlir::ArrayAttr nos, /*optional*/::mlir::IntegerAttr al, /*optional*/::mlir::UnitAttr vo, /*optional*/::mlir::UnitAttr non = nullptr); +// CHECK-DAG: ::mlir::Value value, /*optional*/::mlir::ArrayAttr ag, /*optional*/::mlir::ArrayAttr as, /*optional*/::mlir::ArrayAttr nos, /*optional*/::mlir::IntegerAttr al, /*optional*/bool vo = false, /*optional*/bool non = false); +// CHECK-DAG: ::mlir::TypeRange resultTypes, ::mlir::Value value, /*optional*/::mlir::ArrayAttr ag, /*optional*/::mlir::ArrayAttr as, /*optional*/::mlir::ArrayAttr nos, /*optional*/::mlir::IntegerAttr al, /*optional*/::mlir::UnitAttr vo, /*optional*/::mlir::UnitAttr non = nullptr); +// CHECK-DAG: ::mlir::TypeRange resultTypes, ::mlir::Value value, /*optional*/::mlir::ArrayAttr ag, /*optional*/::mlir::ArrayAttr as, /*optional*/::mlir::ArrayAttr nos, /*optional*/::mlir::IntegerAttr al, /*optional*/bool vo = false, /*optional*/bool non = false); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2883,16 +2883,20 @@ // Successors and variadic regions go at the end of the parameter list, so no // default arguments are possible. bool hasTrailingParams = op.getNumSuccessors() || op.getNumVariadicRegions(); - if (attrParamKind == AttrParamKind::UnwrappedValue && !hasTrailingParams) { + if (!hasTrailingParams) { // Calculate the start index from which we can attach default values in the // builder declaration. 
for (int i = op.getNumArgs() - 1; i >= 0; --i) { auto *namedAttr = llvm::dyn_cast_if_present(op.getArg(i)); - if (!namedAttr || !namedAttr->attr.hasDefaultValue()) + if (!namedAttr) break; - if (!canUseUnwrappedRawValue(namedAttr->attr)) + Attribute attr = namedAttr->attr; + // TODO: Currently we can't differentiate between optional meaning do not + // verify/not always error if missing or optional meaning need not be + // specified in builder. Expand isOptional once we can differentiate. + if (!attr.hasDefaultValue() && !attr.isDerivedAttr()) break; // Creating an APInt requires us to provide bitwidth, value, and @@ -2907,6 +2911,21 @@ defaultValuedAttrStartIndex = i; } } + // Avoid generating build methods that are ambiguous due to default values by + // requiring at least one attribute. + if (defaultValuedAttrStartIndex < op.getNumArgs()) { + // TODO: This should have been possible as a cast but + // required template instantiations is not yet defined for the tblgen helper + // classes. + auto *namedAttr = + cast(op.getArg(defaultValuedAttrStartIndex)); + Attribute attr = namedAttr->attr; + if ((attrParamKind == AttrParamKind::WrappedAttr && + canUseUnwrappedRawValue(attr)) || + (attrParamKind == AttrParamKind::UnwrappedValue && + !canUseUnwrappedRawValue(attr))) + ++defaultValuedAttrStartIndex; + } /// Collect any inferred attributes. for (const NamedTypeConstraint &operand : op.getOperands()) { @@ -2959,9 +2978,12 @@ // Attach default value if requested and possible. std::string defaultValue; - if (attrParamKind == AttrParamKind::UnwrappedValue && - i >= defaultValuedAttrStartIndex) { - defaultValue += attr.getDefaultValue(); + if (i >= defaultValuedAttrStartIndex) { + if (attrParamKind == AttrParamKind::UnwrappedValue && + canUseUnwrappedRawValue(attr)) + defaultValue += attr.getDefaultValue(); + else + defaultValue += "nullptr"; } paramList.emplace_back(type, namedAttr.name, StringRef(defaultValue), attr.isOptional());