diff --git a/mlir/include/mlir/TableGen/Attribute.h b/mlir/include/mlir/TableGen/Attribute.h --- a/mlir/include/mlir/TableGen/Attribute.h +++ b/mlir/include/mlir/TableGen/Attribute.h @@ -64,8 +64,8 @@ // Returns the template that can be used to produce an instance of the // attribute. - // Syntax: {0} should be replaced with a builder, {1} should be replaced with - // the constant value. + // Syntax: `$_builder` should be replaced with a builder, `$0` should be + // replaced with the constant value. StringRef getConstBuilderTemplate() const; // Returns the base-level attribute that this attribute constraint is diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -186,6 +186,11 @@ // DECL-LABEL: DOp declarations // DECL: static void build({{.*}}, APInt i32_attr, APFloat f64_attr, StringRef str_attr, bool bool_attr, ::SomeI32Enum enum_attr, APInt dv_i32_attr, APFloat dv_f64_attr, StringRef dv_str_attr = "abc", bool dv_bool_attr = true, ::SomeI32Enum dv_enum_attr = ::SomeI32Enum::case5) +// DEF-LABEL: DOp definitions +// DEF: odsState.addAttribute("str_attr", (*odsBuilder).getStringAttr(str_attr)); +// DEF: odsState.addAttribute("dv_str_attr", (*odsBuilder).getStringAttr(dv_str_attr)); + + // Test that only default valued attributes at the end of the arguments // list get default values in the builder signature // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Support/STLExtras.h" +#include "mlir/Support/StringExtras.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/OpClass.h" @@ -91,6 +92,29 @@ // Utility structs and functions 
//===----------------------------------------------------------------------===//
 
+/// Returns a copy of `input` with every non-overlapping occurrence of `match`
+/// replaced by `substitute`. Matches are found left to right, substituted text
+/// is not rescanned, and an empty `match` leaves `input` unchanged.
+static std::string replaceAllSubstrs(StringRef input, StringRef match,
+                                     StringRef substitute) {
+  if (match.empty())
+    return input.str();
+
+  std::string result;
+  // Location in input to start scanning for the next match.
+  size_t scanLoc = 0;
+  // Location in input of the next match.
+  size_t matchLoc = StringRef::npos;
+  while ((matchLoc = input.find(match, scanLoc)) != StringRef::npos) {
+    result.append(input.data() + scanLoc, matchLoc - scanLoc);
+    result.append(substitute.data(), substitute.size());
+    scanLoc = matchLoc + match.size();
+  }
+  // Keep whatever follows the last match (possibly nothing).
+  result.append(input.data() + scanLoc, input.size() - scanLoc);
+  return result;
+}
+
 // Returns whether the record has a value of the given name that can be returned
 // via getValueAsString.
 static inline bool hasStringAttribute(const Record &record,
@@ -968,8 +992,15 @@
       // instance.
       FmtContext fctx;
       fctx.withBuilder("(*odsBuilder)");
-      std::string value =
-          tgfmt(attr.getConstBuilderTemplate(), &fctx, namedAttr.name);
+
+      // For StringAttr, the constant builder call wraps `$0` in quotes, which
+      // is right for string literals but wrong here: `$0` expands to the name
+      // of a function argument. Strip the wrapping quotes from every
+      // occurrence.
+      std::string builderTemplate =
+          replaceAllSubstrs(attr.getConstBuilderTemplate(), "\"$0\"", "$0");
+
+      std::string value = tgfmt(builderTemplate, &fctx, namedAttr.name);
       body << formatv(" {0}.addAttribute(\"{1}\", {2});\n",
                       builderOpState, namedAttr.name, value);
     } else {