diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -460,9 +460,6 @@ let assemblyFormat = "(`stable` $stable^)? $n" "`,`$xs (`jointly` $ys^)? attr-dict" "`:` type($xs) (`jointly` type($ys)^)?"; - let builders = [ - OpBuilder<(ins "Value":$n, "ValueRange":$xs, "ValueRange":$ys)> - ]; let hasVerifier = 1; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1246,9 +1246,10 @@ // "true" if the attribute is present and "false" otherwise. def UnitAttr : Attr()">, "unit attribute"> { let storageType = [{ ::mlir::UnitAttr }]; - let constBuilderCall = "$_builder.getUnitAttr()"; + let constBuilderCall = "(($0) ? $_builder.getUnitAttr() : nullptr)"; let convertFromStorage = "$_self != nullptr"; let returnType = "bool"; + let defaultValue = "false"; let valueType = NoneType; let isOptional = 1; } @@ -1575,7 +1576,7 @@ class ConstF32Attr : ConstantAttr; def ConstBoolAttrFalse : ConstantAttr; def ConstBoolAttrTrue : ConstantAttr; -def ConstUnitAttr : ConstantAttr; +def ConstUnitAttr : ConstantAttr; // Constant string-based attribute. Wraps the desired string in escaped quotes. class ConstantStrAttr diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -706,11 +706,6 @@ return success(); } -void SortOp::build(OpBuilder &odsBuilder, OperationState &odsState, Value n, - ValueRange xs, ValueRange ys) { - build(odsBuilder, odsState, n, xs, ys, /*stable=*/false); -} - LogicalResult SortOp::verify() { if (getXs().empty()) return emitError("need at least one xs buffer."); diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -488,6 +488,11 @@ // DEF-NEXT: (*this)->removeAttr(getAttrAttrName()); // DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/::mlir::UnitAttr attr) +// DEF: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr) + +// DECL-LABEL: UnitAttrOp declarations +// DECL-NOT: declarations +// DECL: build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, /*optional*/bool attr = false) // Test elementAttr field of TypedArrayAttr. diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1635,9 +1635,9 @@ } void OpEmitter::genPopulateDefaultAttributes() { - // All done if no attributes have default values. + // All done if no attributes, except optional ones, have default values. if (llvm::all_of(op.getAttributes(), [](const NamedAttribute &named) { - return !named.attr.hasDefaultValue(); + return !named.attr.hasDefaultValue() || named.attr.isOptional(); })) return; @@ -1667,8 +1667,8 @@ fctx.withBuilder(odsBuilder); std::string defaultValue = std::string( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body.indent() << formatv(" attributes.append(attrNames[{0}], {1});\n", - index, defaultValue); + body.indent() << formatv("attributes.append(attrNames[{0}], {1});\n", index, + defaultValue); body.unindent() << "}\n"; } } @@ -2143,12 +2143,16 @@ if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) continue; - // TODO(jpienaar): The wrapping of optional is different for default or not, - // so don't unwrap for default ones that would fail below. - bool emitNotNullCheck = (attr.isOptional() && !attr.hasDefaultValue()) || - (attr.hasDefaultValue() && !isRawValueAttr); + // TODO: The wrapping of optional is different for default or not, so don't + // unwrap for default ones that would fail below. + bool emitNotNullCheck = + (attr.isOptional() && !attr.hasDefaultValue()) || + (attr.hasDefaultValue() && !isRawValueAttr) || + // TODO: UnitAttr is optional, not wrapped, but needs to be guarded as + // the constant materialization is only for true case. + (isRawValueAttr && attr.getAttrDefName() == "UnitAttr"); if (emitNotNullCheck) - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + body.indent() << formatv("if ({0}) ", namedAttr.name) << "{\n"; if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { // If this is a raw value, then we need to wrap it in an Attribute @@ -2175,7 +2179,7 @@ namedAttr.name); } if (emitNotNullCheck) - body << " }\n"; + body.unindent() << " }\n"; } // Create the correct number of regions. @@ -2966,7 +2970,7 @@ // call. This should be set instead. std::string defaultValue = std::string( tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); - body << " if (!attr)\n attr = " << defaultValue << ";\n"; + body << "if (!attr)\n attr = " << defaultValue << ";\n"; } body << "return attr;\n"; };