Add DefaultValuedStrAttr and ConstantStrAttr to ODS

StringBasedAttr's constBuilderCall no longer wraps its argument in quotes,
and OpDefinitionsGen no longer quotes default values of attributes whose
return type is ::llvm::StringRef. The quotes are instead supplied by the new
DefaultValuedStrAttr / ConstantStrAttr classes, which wrap the given value
in escaped quotes. RewriterGen's handleConstantAttr now takes a Twine so the
string enum-case path can add the quotes at the call site.

RewriterGen also escapes constraint summaries (and wraps them in '...')
before splicing them into the generated C++ string literals used for
match-failure messages. The escaping uses raw_ostream::write_escaped, which
emits C-style escapes (\", \\, \n, \t, octal \ooo) that are valid inside a
C++ string literal. llvm::printEscapedString is not usable here: it emits a
backslash plus two bare hex digits (e.g. \22 for a double quote), which a
C++ compiler parses as an octal escape and so corrupts the message.

NOTE(review): this copy of the patch appears damaged -- line breaks are
collapsed and angle-bracketed template arguments look stripped (e.g.
"class DefaultValuedStrAttr : DefaultValuedAttr;", "isa(constraint)").
Restore them from the original change before applying.

diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -928,6 +928,11 @@ let baseAttr = attr; } +// Default-valued string-based attribute. Wraps the default value in escaped +// quotes. +class DefaultValuedStrAttr + : DefaultValuedAttr; + //===----------------------------------------------------------------------===// // Primitive attribute kinds @@ -1095,7 +1100,7 @@ // An attribute backed by a string type. class StringBasedAttr : Attr { - let constBuilderCall = "$_builder.getStringAttr(\"$0\")"; + let constBuilderCall = "$_builder.getStringAttr($0)"; let storageType = [{ ::mlir::StringAttr }]; let returnType = [{ ::llvm::StringRef }]; let valueType = NoneType; @@ -1672,6 +1677,10 @@ def ConstBoolAttrTrue : ConstantAttr; def ConstUnitAttr : ConstantAttr; +// Constant string-based attribute. Wraps the desired string in escaped quotes. +class ConstantStrAttr + : ConstantAttr; + //===----------------------------------------------------------------------===// // Common attribute constraints //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2343,4 +2343,21 @@ }]; } +//===----------------------------------------------------------------------===// +// Test Ops with Default-Valued String Attributes +//===----------------------------------------------------------------------===// + +def TestDefaultStrAttrNoValueOp : TEST_Op<"no_str_value"> { + let arguments = (ins DefaultValuedAttr:$value); + let assemblyFormat = "attr-dict"; +} + +def TestDefaultStrAttrHasValueOp : TEST_Op<"has_str_value"> { + let arguments = (ins DefaultValuedStrAttr:$value); + let assemblyFormat = "attr-dict"; +} + +def : Pat<(TestDefaultStrAttrNoValueOp $value), + (TestDefaultStrAttrHasValueOp 
ConstantStrAttr)>; + #endif // TEST_OPS diff --git a/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir b/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/constant-str-attr-invalid.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt -verify-diagnostics %s + +// Test DefaultValuedAttr is recognized as "no default value" +test.no_str_value {} // expected-error {{'test.no_str_value' op requires attribute 'value'}} diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -337,7 +337,7 @@ SomeI32Enum:$enum_attr, DefaultValuedAttr:$dv_i32_attr, DefaultValuedAttr:$dv_f64_attr, - DefaultValuedAttr:$dv_str_attr, + DefaultValuedStrAttr:$dv_str_attr, DefaultValuedAttr:$dv_bool_attr, DefaultValuedAttr:$dv_enum_attr ); @@ -377,7 +377,7 @@ F64Attr:$f64_attr, DefaultValuedAttr:$dv_f64_attr, StrAttr:$str_attr, - DefaultValuedAttr:$dv_str_attr, + DefaultValuedStrAttr:$dv_str_attr, BoolAttr:$bool_attr, DefaultValuedAttr:$dv_bool_attr, SomeI32Enum:$enum_attr, diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -361,3 +361,10 @@ // CHECK: test.format_infer_type %ignored_res7 = test.format_infer_type + +//===----------------------------------------------------------------------===// +// Check DefaultValuedStrAttr +//===----------------------------------------------------------------------===// + +// CHECK: test.has_str_value +test.has_str_value {} diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -583,3 +583,13 @@ // CHECK: "test.two_to_one"(%0, %1) : (i32, i32) -> i1 return %0 : i1 } + +//===----------------------------------------------------------------------===// 
+// Test that patterns can create ConstantStrAttr +//===----------------------------------------------------------------------===// + +func @testConstantStrAttr() -> () { + // CHECK: test.has_str_value {value = "foo"} + test.no_str_value {value = "bar"} + return +} diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1710,12 +1710,7 @@ std::string defaultValue; if (attrParamKind == AttrParamKind::UnwrappedValue && i >= defaultValuedAttrStartIndex) { - bool isString = attr.getReturnType() == "::llvm::StringRef"; - if (isString) - defaultValue.append("\""); defaultValue += attr.getDefaultValue(); - if (isString) - defaultValue.append("\""); } paramList.emplace_back(type, namedAttr.name, defaultValue, properties); } diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -50,6 +50,13 @@ }; } // end namespace llvm +static std::string escapeString(StringRef value) { + std::string ret; + llvm::raw_string_ostream os(ret); + os.write_escaped(value); + return os.str(); +} + //===----------------------------------------------------------------------===// // PatternEmitter //===----------------------------------------------------------------------===// @@ -189,7 +196,7 @@ // Returns the C++ expression to construct a constant attribute of the given // `value` for the given attribute kind `attr`. - std::string handleConstantAttr(Attribute attr, StringRef value); + std::string handleConstantAttr(Attribute attr, const Twine &value); // Returns the C++ expression to build an argument from the given DAG `leaf`. // `patArgName` is used to bound the argument to the source pattern. 
@@ -313,7 +320,7 @@ } std::string PatternEmitter::handleConstantAttr(Attribute attr, - StringRef value) { + const Twine &value) { if (!attr.isConstBuildable()) PrintFatalError(loc, "Attribute " + attr.getAttrDefName() + " does not have the 'constBuilderCall' field"); @@ -492,7 +499,8 @@ formatv("\"operand {0} of native code call '{1}' failed to satisfy " "constraint: " "'{2}'\"", - i, tree.getNativeCodeTemplate(), constraint.getSummary())); + i, tree.getNativeCodeTemplate(), + escapeString(constraint.getSummary()))); } LLVM_DEBUG(llvm::dbgs() << "done emitting match for native code call\n"); @@ -630,7 +638,7 @@ formatv("\"operand {0} of op '{1}' failed to satisfy constraint: " "'{2}'\"", operand - op.operand_begin(), op.getOperationName(), - constraint.getSummary())); + escapeString(constraint.getSummary()))); } } @@ -694,9 +702,9 @@ opName, tgfmt(matcher.getConditionTemplate(), &fmtCtx.withSelf("tblgen_attr")), formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " - "{2}\"", + "'{2}'\"", op.getOperationName(), namedAttr->name, - matcher.getAsConstraint().getSummary())); + escapeString(matcher.getAsConstraint().getSummary()))); } // Capture the value @@ -740,8 +748,8 @@ symbolInfoMap.getValueAndRangeUse(entities.front())); emitMatchCheck( opName, tgfmt(condition, &fmtCtx.withSelf(self.str())), - formatv("\"value entity '{0}' failed to satisfy constraint: {1}\"", - entities.front(), constraint.getSummary())); + formatv("\"value entity '{0}' failed to satisfy constraint: '{1}'\"", + entities.front(), escapeString(constraint.getSummary()))); } else if (isa(constraint)) { PrintFatalError( @@ -765,9 +773,9 @@ tgfmt(condition, &fmtCtx.withSelf(self), names[0], names[1], names[2], names[3]), formatv("\"entities '{0}' failed to satisfy constraint: " - "{1}\"", + "'{1}'\"", llvm::join(entities, ", "), - constraint.getSummary())); + escapeString(constraint.getSummary()))); } } @@ -1103,7 +1111,7 @@ if (leaf.isEnumAttrCase()) { auto enumCase = 
leaf.getAsEnumAttrCase(); if (enumCase.isStrCase()) - return handleConstantAttr(enumCase, enumCase.getSymbol()); + return handleConstantAttr(enumCase, "\"" + enumCase.getSymbol() + "\""); // This is an enum case backed by an IntegerAttr. We need to get its value // to build the constant. std::string val = std::to_string(enumCase.getValue());