diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -1274,6 +1274,14 @@ getZeroValue() const { return getZeroAPFloat(); } + + /// Get a zero for a StringRef. + template <typename T> + typename std::enable_if<std::is_same<StringRef, T>::value, T>::type + getZeroValue() const { + return StringRef(); + } + /// Get a zero for an C++ integer or float type. template <typename T> typename std::enable_if<std::numeric_limits<T>::is_integer || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -976,6 +976,11 @@ /// Print a dense string elements attribute. void printDenseStringElementsAttr(DenseStringElementsAttr attr); + /// Print a dense elements attribute. If 'allowHex' is true, a hex string is + /// used instead of individual elements when the elements attr is large. + void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, + bool allowHex); + void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1396,13 +1401,13 @@ break; } case StandardAttributes::DenseIntOrFPElements: { - auto eltsAttr = attr.cast<DenseElementsAttr>(); + auto eltsAttr = attr.cast<DenseIntOrFPElementsAttr>(); if (printerFlags.shouldElideElementsAttr(eltsAttr)) { printElidedElementsAttr(os); break; } os << "dense<"; - printDenseElementsAttr(eltsAttr, /*allowHex=*/true); + printDenseIntOrFPElementsAttr(eltsAttr, /*allowHex=*/true); os << '>'; break; } @@ -1425,7 +1430,8 @@ break; } os << "sparse<"; - printDenseElementsAttr(elementsAttr.getIndices(), /*allowHex=*/false); + printDenseIntOrFPElementsAttr(elementsAttr.getIndices(), + /*allowHex=*/false); os << ", "; printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true); os << '>'; @@ -1477,6 +1483,17 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr, bool allowHex) { + if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>()) { + printDenseStringElementsAttr(stringAttr); + return; + } + + 
printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(), + allowHex); +} + +void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, + bool allowHex) { auto type = attr.getType(); auto shape = type.getShape(); auto rank = type.getRank(); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -764,6 +764,11 @@ "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> () // CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> () "foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> () + +// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () + "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> () +// CHECK: "foostr"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}"a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> () + "foostr"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], ["a", "b", "c"]> : tensor<2x2x2x!unknown<"">>} : () -> () return }