diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -52,7 +52,7 @@ MLIRContext *getContext() const { return context; } - Identifier getIdentifier(StringRef str); + Identifier getIdentifier(const Twine &str); // Locations. Location getUnknownLoc(); @@ -94,7 +94,7 @@ IntegerAttr getIntegerAttr(Type type, const APInt &value); FloatAttr getFloatAttr(Type type, double value); FloatAttr getFloatAttr(Type type, const APFloat &value); - StringAttr getStringAttr(StringRef bytes); + StringAttr getStringAttr(const Twine &bytes); ArrayAttr getArrayAttr(ArrayRef value); FlatSymbolRefAttr getSymbolRefAttr(Operation *value); FlatSymbolRefAttr getSymbolRefAttr(StringRef value); @@ -393,7 +393,7 @@ /// Create an operation of specific op type at the current insertion point. template - OpTy create(Location location, Args &&... args) { + OpTy create(Location location, Args &&...args) { OperationState state(location, OpTy::getOperationName()); if (!state.name.getAbstractOperation()) llvm::report_fatal_error("Building op `" + @@ -411,7 +411,7 @@ /// the results after folding the operation. template void createOrFold(SmallVectorImpl &results, Location location, - Args &&... args) { + Args &&...args) { // Create the operation without using 'createOperation' as we don't want to // insert it yet. OperationState state(location, OpTy::getOperationName()); @@ -433,7 +433,7 @@ template typename std::enable_if(), Value>::type - createOrFold(Location location, Args &&... args) { + createOrFold(Location location, Args &&...args) { SmallVector results; createOrFold(results, location, std::forward(args)...); return results.front(); @@ -443,7 +443,7 @@ template typename std::enable_if(), OpTy>::type - createOrFold(Location location, Args &&... 
args) { + createOrFold(Location location, Args &&...args) { auto op = create(location, std::forward(args)...); SmallVector unused; tryFold(op.getOperation(), unused); diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -812,16 +812,9 @@ let parameters = (ins StringRefParameter<"">:$value, AttributeSelfTypeParameter<"">:$type); let builders = [ - AttrBuilderWithInferredContext<(ins "StringRef":$bytes, - "Type":$type), [{ - return $_get(type.getContext(), bytes, type); - }]>, - AttrBuilder<(ins "StringRef":$bytes), [{ - if (bytes.empty()) - return get($_ctxt); - return $_get($_ctxt, bytes, NoneType::get($_ctxt)); - }]>, - + AttrBuilderWithInferredContext<(ins "const Twine &":$bytes, "Type":$type)>, + /// Build a string attr with NoneType. + AttrBuilder<(ins "const Twine &":$bytes)>, /// Build an empty string attr with NoneType. AttrBuilder<(ins)> ]; diff --git a/mlir/include/mlir/IR/Identifier.h b/mlir/include/mlir/IR/Identifier.h --- a/mlir/include/mlir/IR/Identifier.h +++ b/mlir/include/mlir/IR/Identifier.h @@ -13,6 +13,7 @@ #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/StringMapEntry.h" +#include "llvm/ADT/Twine.h" #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { @@ -38,7 +39,8 @@ public: /// Return an identifier for the specified string. 
- static Identifier get(StringRef str, MLIRContext *context); + static Identifier get(const Twine &string, MLIRContext *context); + Identifier(const Identifier &) = default; Identifier &operator=(const Identifier &other) = default; diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -70,9 +70,8 @@ SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) - attributes.emplace_back( - Identifier::get(unwrap(elements[i].name), unwrap(ctx)), - unwrap(elements[i].attribute)); + attributes.emplace_back(unwrap(elements[i].name), + unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -177,7 +177,8 @@ // Set SPIR-V binary shader data as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVBlobAttrName, - StringAttr::get(loc->getContext(), {binary.data(), binary.size()})); + StringAttr::get(loc->getContext(), + StringRef(binary.data(), binary.size()))); // Set entry point name as an attribute. vulkanLaunchCallOp->setAttr( diff --git a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp --- a/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SerializeToBlob.cpp @@ -66,7 +66,8 @@ return signalPassFailure(); // Add the blob as module attribute. 
- auto attr = StringAttr::get(&getContext(), {blob->data(), blob->size()}); + auto attr = + StringAttr::get(&getContext(), StringRef(blob->data(), blob->size())); getOperation()->setAttr(gpuBinaryAnnotation, attr); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -95,7 +95,7 @@ Operation *op) const { if (replacement.hasValue()) op->setAttr(LinalgTransforms::kLinalgTransformMarker, - rewriter.getStringAttr(replacement.getValue())); + rewriter.getStringAttr(replacement.getValue().strref())); else op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker, rewriter.getContext())); diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -291,12 +291,9 @@ if (parser.parseLParen() || parser.parseKeyword(&defval) || parser.parseRParen()) return failure(); - SmallString<16> attrval; // The def prefix is required for the attribute as "private" is a keyword // in C++. - attrval += "def"; - attrval += defval; - auto attr = parser.getBuilder().getStringAttr(attrval); + auto attr = parser.getBuilder().getStringAttr("def" + defval); result.addAttribute("default_val", attr); } else if (keyword == "proc_bind") { // Fail if there was already another proc_bind clause. 
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -19,7 +19,7 @@ using namespace mlir; -Identifier Builder::getIdentifier(StringRef str) { +Identifier Builder::getIdentifier(const Twine &str) { return Identifier::get(str, context); } @@ -200,7 +200,7 @@ return FloatAttr::get(type, value); } -StringAttr Builder::getStringAttr(StringRef bytes) { +StringAttr Builder::getStringAttr(const Twine &bytes) { return StringAttr::get(context, bytes); } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -197,10 +197,29 @@ return Base::get(context, ArrayRef()); } +//===----------------------------------------------------------------------===// +// StringAttr +//===----------------------------------------------------------------------===// + StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { return Base::get(context, "", NoneType::get(context)); } +/// Twine support for StringAttr. +StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { + // Fast-path empty twine. + if (twine.isTriviallyEmpty()) + return get(context); + SmallVector tempStr; + return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); +} + +/// Twine support for StringAttr. 
+StringAttr StringAttr::get(const Twine &twine, Type type) { + SmallVector tempStr; + return Base::get(type.getContext(), twine.toStringRef(tempStr), type); +} + //===----------------------------------------------------------------------===// // FloatAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -753,7 +753,10 @@ //===----------------------------------------------------------------------===// /// Return an identifier for the specified string. -Identifier Identifier::get(StringRef str, MLIRContext *context) { +Identifier Identifier::get(const Twine &string, MLIRContext *context) { + SmallString<32> tempStr; + auto str = string.toStringRef(tempStr); + // Check invariants after seeing if we already have something in the // identifier table - if we already had it in the table, then it already // passed invariant checks.