diff --git a/mlir/include/mlir/Dialect/GPU/GPUOps.td b/mlir/include/mlir/Dialect/GPU/GPUOps.td --- a/mlir/include/mlir/Dialect/GPU/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/GPUOps.td @@ -378,10 +378,10 @@ unsigned getNumKernelOperands(); /// The name of the kernel's containing module. - StringRef getKernelModuleName(); + StringAttr getKernelModuleName(); /// The name of the kernel. - StringRef getKernelName(); + StringAttr getKernelName(); /// The i-th operand passed to the kernel function. Value getKernelOperand(unsigned i); diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -98,9 +98,16 @@ StringAttr getStringAttr(const Twine &bytes); ArrayAttr getArrayAttr(ArrayRef value); FlatSymbolRefAttr getSymbolRefAttr(Operation *value); - FlatSymbolRefAttr getSymbolRefAttr(StringRef value); - SymbolRefAttr getSymbolRefAttr(StringRef value, + FlatSymbolRefAttr getSymbolRefAttr(StringAttr value); + SymbolRefAttr getSymbolRefAttr(StringAttr value, ArrayRef nestedReferences); + SymbolRefAttr getSymbolRefAttr(StringRef value, + ArrayRef nestedReferences) { + return getSymbolRefAttr(getStringAttr(value), nestedReferences); + } + FlatSymbolRefAttr getSymbolRefAttr(StringRef value) { + return getSymbolRefAttr(getStringAttr(value)); + } // Returns a 0-valued attribute of the given `type`. 
This function only // supports boolean, integer, and 16-/32-/64-bit float types, and vector or diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -30,8 +30,10 @@ //===----------------------------------------------------------------------===// namespace detail { -template class ElementsAttrIterator; -template class ElementsAttrRange; +template +class ElementsAttrIterator; +template +class ElementsAttrRange; } // namespace detail /// A base attribute that represents a reference to a static shaped tensor or @@ -39,8 +41,10 @@ class ElementsAttr : public Attribute { public: using Attribute::Attribute; - template using iterator = detail::ElementsAttrIterator; - template using iterator_range = detail::ElementsAttrRange; + template + using iterator = detail::ElementsAttrIterator; + template + using iterator_range = detail::ElementsAttrRange; /// Return the type of this ElementsAttr, guaranteed to be a vector or tensor /// with static shape. @@ -52,14 +56,16 @@ /// Return the value of type 'T' at the given index, where 'T' corresponds to /// an Attribute type. - template T getValue(ArrayRef index) const { + template + T getValue(ArrayRef index) const { return getValue(index).template cast(); } /// Return the elements of this attribute as a value of type 'T'. Note: /// Aborts if the subclass is OpaqueElementsAttrs, these attrs do not support /// iteration. - template iterator_range getValues() const; + template + iterator_range getValues() const; /// Return if the given 'index' refers to a valid element in this attribute. bool isValidIndex(ArrayRef index) const; @@ -139,7 +145,8 @@ }; /// Type trait detector that checks if a given type T is a complex type. 
-template struct is_complex_t : public std::false_type {}; +template +struct is_complex_t : public std::false_type {}; template struct is_complex_t> : public std::true_type {}; } // namespace detail @@ -154,7 +161,8 @@ /// floating point type that can be used to access the underlying element /// types of a DenseElementsAttr. // TODO: Use std::disjunction when C++17 is supported. - template struct is_valid_cpp_fp_type { + template + struct is_valid_cpp_fp_type { /// The type is a valid floating point type if it is a builtin floating /// point type, or is a potentially user defined floating point type. The /// latter allows for supporting users that have custom types defined for @@ -423,7 +431,8 @@ Attribute getValue(ArrayRef index) const { return getValue(index); } - template T getValue(ArrayRef index) const { + template + T getValue(ArrayRef index) const { // Skip to the element corresponding to the flattened index. return *std::next(getValues().begin(), getFlattenedIndex(index)); } @@ -680,8 +689,15 @@ return SymbolRefAttr::get(ctx, value); } + static FlatSymbolRefAttr get(StringAttr value) { + return SymbolRefAttr::get(value); + } + + /// Returns the name of the held symbol reference as a StringAttr. + StringAttr getAttr() const { return getRootReference(); } + /// Returns the name of the held symbol reference. - StringRef getValue() const { return getRootReference(); } + StringRef getValue() const { return getAttr().getValue(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool classof(Attribute attr) { @@ -845,22 +861,28 @@ } /// Utility functors used to generically implement the iterators methods. 
- template struct PlusAssign { + template + struct PlusAssign { void operator()(ItT &it, ptrdiff_t offset) { it += offset; } }; - template struct Minus { + template + struct Minus { ptrdiff_t operator()(const ItT &lhs, const ItT &rhs) { return lhs - rhs; } }; - template struct MinusAssign { + template + struct MinusAssign { void operator()(ItT &it, ptrdiff_t offset) { it -= offset; } }; - template struct Dereference { + template + struct Dereference { T operator()(ItT &it) { return *it; } }; - template struct ConstructIter { + template + struct ConstructIter { void operator()(ItT &dest, const ItT &it) { ::new (&dest) ItT(it); } }; - template struct DestructIter { + template + struct DestructIter { void operator()(ItT &it) { it.~ItT(); } }; diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -881,17 +881,26 @@ @parent_reference::@nested_reference ``` }]; - let parameters = (ins - StringRefParameter<"">:$rootReference, - ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences - ); + let parameters = + (ins "StringAttr":$rootReference, + ArrayRefParameter<"FlatSymbolRefAttr", "">:$nestedReferences); + + let builders = [ + AttrBuilderWithInferredContext< + (ins "StringAttr":$rootReference, + "ArrayRef":$nestedReferences), [{ + return $_get(rootReference.getContext(), rootReference, nestedReferences); + }]>, + ]; let extraClassDeclaration = [{ static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); + static FlatSymbolRefAttr get(StringAttr value); /// Returns the name of the fully resolved symbol, i.e. the leaf of the /// reference path. 
- StringRef getLeafReference() const; + StringAttr getLeafReference() const; }]; + let skipDefaultBuilders = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1734,7 +1734,7 @@ class ReferToOp : AttrConstraint< CPred<"isa_and_nonnull<" # opClass # ">(" "::mlir::SymbolTable::lookupNearestSymbolFrom(" - "&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getValue()))">, + "&$_op, $_self.cast<::mlir::FlatSymbolRefAttr>().getAttr()))">, "referencing to a '" # opClass # "' symbol">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -31,7 +31,7 @@ let methods = [ InterfaceMethod<"Returns the name of this symbol.", - "StringRef", "getName", (ins), [{ + "StringAttr", "getNameAttr", (ins), [{ // Don't rely on the trait implementation as optional symbol operations // may override this. return mlir::SymbolTable::getSymbolName($_op); @@ -40,11 +40,10 @@ }] >, InterfaceMethod<"Sets the name of this symbol.", - "void", "setName", (ins "StringRef":$name), [{}], + "void", "setName", (ins "StringAttr":$name), [{}], /*defaultImplementation=*/[{ this->getOperation()->setAttr( - mlir::SymbolTable::getSymbolAttrName(), - StringAttr::get(this->getOperation()->getContext(), name)); + mlir::SymbolTable::getSymbolAttrName(), name); }] >, InterfaceMethod<"Gets the visibility of this symbol.", @@ -122,7 +121,7 @@ symbol 'newSymbol' that are nested within the given operation 'from'. Note: See mlir::SymbolTable::replaceAllSymbolUses for more details. 
}], - "LogicalResult", "replaceAllSymbolUses", (ins "StringRef":$newSymbol, + "LogicalResult", "replaceAllSymbolUses", (ins "StringAttr":$newSymbol, "Operation *":$from), [{}], /*defaultImplementation=*/[{ return ::mlir::SymbolTable::replaceAllSymbolUses(this->getOperation(), @@ -176,6 +175,16 @@ }]; let extraClassDeclaration = [{ + /// Convenience version of `getNameAttr` that returns a StringRef. + StringRef getName() { + return getNameAttr().getValue(); + } + + /// Convenience version of `setName` that takes a StringRef. + void setName(StringRef name) { + setName(StringAttr::get(this->getContext(), name)); + } + /// Custom classof that handles the case where the symbol is optional. static bool classof(Operation *op) { auto *opConcept = getInterfaceFor(op); @@ -188,6 +197,16 @@ let extraTraitClassDeclaration = [{ using Visibility = mlir::SymbolTable::Visibility; + + /// Convenience version of `getNameAttr` that returns a StringRef. + StringRef getName() { + return getNameAttr().getValue(); + } + + /// Convenience version of `setName` that takes a StringRef. + void setName(StringRef name) { + setName(StringAttr::get(this->getContext(), name)); + } }]; } diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -30,7 +30,16 @@ /// Look up a symbol with the specified name, returning null if no such /// name exists. Names never include the @ on them. Operation *lookup(StringRef name) const; - template T lookup(StringRef name) const { + template + T lookup(StringRef name) const { + return dyn_cast_or_null(lookup(name)); + } + + /// Look up a symbol with the specified name, returning null if no such + /// name exists. Names never include the @ on them. 
+ Operation *lookup(StringAttr name) const; + template + T lookup(StringAttr name) const { return dyn_cast_or_null(lookup(name)); } @@ -74,10 +83,15 @@ Nested, }; - /// Returns the name of the given symbol operation. - static StringRef getSymbolName(Operation *symbol); + /// Returns the name of the given symbol operation, aborting if no symbol is + /// present. + static StringAttr getSymbolName(Operation *symbol); + /// Sets the name of the given symbol operation. - static void setSymbolName(Operation *symbol, StringRef name); + static void setSymbolName(Operation *symbol, StringAttr name); + static void setSymbolName(Operation *symbol, StringRef name) { + setSymbolName(symbol, StringAttr::get(symbol->getContext(), name)); + } /// Returns the visibility of the given symbol operation. static Visibility getSymbolVisibility(Operation *symbol); @@ -100,7 +114,10 @@ /// Returns the operation registered with the given symbol name with the /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation /// with the 'OpTrait::SymbolTable' trait. - static Operation *lookupSymbolIn(Operation *op, StringRef symbol); + static Operation *lookupSymbolIn(Operation *op, StringAttr symbol); + static Operation *lookupSymbolIn(Operation *op, StringRef symbol) { + return lookupSymbolIn(op, StringAttr::get(op->getContext(), symbol)); + } static Operation *lookupSymbolIn(Operation *op, SymbolRefAttr symbol); /// A variant of 'lookupSymbolIn' that returns all of the symbols referenced /// by a given SymbolRefAttr. Returns failure if any of the nested references @@ -112,11 +129,11 @@ /// closest parent operation of, or including, 'from' with the /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was /// found. 
- static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); + static Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); static Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); template - static T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { + static T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { return dyn_cast_or_null(lookupNearestSymbolFrom(from, symbol)); } template @@ -169,9 +186,9 @@ /// operation 'from'. This does not traverse into any nested symbol tables. /// This function returns None if there are any unknown operations that may /// potentially be symbol tables. - static Optional getSymbolUses(StringRef symbol, Operation *from); + static Optional getSymbolUses(StringAttr symbol, Operation *from); static Optional getSymbolUses(Operation *symbol, Operation *from); - static Optional getSymbolUses(StringRef symbol, Region *from); + static Optional getSymbolUses(StringAttr symbol, Region *from); static Optional getSymbolUses(Operation *symbol, Region *from); /// Return if the given symbol is known to have no uses that are nested @@ -180,9 +197,9 @@ /// unknown operations that may potentially be symbol tables. This doesn't /// necessarily mean that there are no uses, we just can't conservatively /// prove it. - static bool symbolKnownUseEmpty(StringRef symbol, Operation *from); + static bool symbolKnownUseEmpty(StringAttr symbol, Operation *from); static bool symbolKnownUseEmpty(Operation *symbol, Operation *from); - static bool symbolKnownUseEmpty(StringRef symbol, Region *from); + static bool symbolKnownUseEmpty(StringAttr symbol, Region *from); static bool symbolKnownUseEmpty(Operation *symbol, Region *from); /// Attempt to replace all uses of the given symbol 'oldSymbol' with the @@ -190,23 +207,24 @@ /// 'from'. This does not traverse into any nested symbol tables. 
If there are /// any unknown operations that may potentially be symbol tables, no uses are /// replaced and failure is returned. - static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, + static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, + StringAttr newSymbol, Operation *from); static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, - StringRef newSymbolName, + StringAttr newSymbolName, Operation *from); - static LogicalResult replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, Region *from); + static LogicalResult replaceAllSymbolUses(StringAttr oldSymbol, + StringAttr newSymbol, Region *from); static LogicalResult replaceAllSymbolUses(Operation *oldSymbol, - StringRef newSymbolName, + StringAttr newSymbolName, Region *from); private: Operation *symbolTableOp; - /// This is a mapping from a name to the symbol with that name. - llvm::StringMap symbolTable; + /// This is a mapping from a name to the symbol with that name. The key is + /// always known to be a StringAttr. + DenseMap symbolTable; /// This is used when name conflicts are detected. unsigned uniquingCounter = 0; @@ -226,7 +244,7 @@ public: /// Look up a symbol with the specified name within the specified symbol table /// operation, returning null if no such name exists. - Operation *lookupSymbolIn(Operation *symbolTableOp, StringRef symbol); + Operation *lookupSymbolIn(Operation *symbolTableOp, StringAttr symbol); Operation *lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name); template T lookupSymbolIn(Operation *symbolTableOp, NameT &&name) const { @@ -244,10 +262,10 @@ /// closest parent operation of, or including, 'from' with the /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was /// found. 
- Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol); + Operation *lookupNearestSymbolFrom(Operation *from, StringAttr symbol); Operation *lookupNearestSymbolFrom(Operation *from, SymbolRefAttr symbol); template - T lookupNearestSymbolFrom(Operation *from, StringRef symbol) { + T lookupNearestSymbolFrom(Operation *from, StringAttr symbol) { return dyn_cast_or_null(lookupNearestSymbolFrom(from, symbol)); } template @@ -290,7 +308,7 @@ } /// Replace all of the uses of the given symbol with `newSymbolName`. - void replaceAllUsesWith(Operation *symbol, StringRef newSymbolName); + void replaceAllUsesWith(Operation *symbol, StringAttr newSymbolName); private: /// A reference to the symbol table used to construct this map. @@ -327,18 +345,28 @@ /// Look up a symbol with the specified name, returning null if no such /// name exists. Symbol names never include the @ on them. Note: This /// performs a linear scan of held symbols. - Operation *lookupSymbol(StringRef name) { + Operation *lookupSymbol(StringAttr name) { return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); } - template T lookupSymbol(StringRef name) { + template + T lookupSymbol(StringAttr name) { return dyn_cast_or_null(lookupSymbol(name)); } Operation *lookupSymbol(SymbolRefAttr symbol) { return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), symbol); } - template T lookupSymbol(SymbolRefAttr symbol) { + template + T lookupSymbol(SymbolRefAttr symbol) { return dyn_cast_or_null(lookupSymbol(symbol)); } + + Operation *lookupSymbol(StringRef name) { + return mlir::SymbolTable::lookupSymbolIn(this->getOperation(), name); + } + template + T lookupSymbol(StringRef name) { + return dyn_cast_or_null(lookupSymbol(name)); + } }; } // end namespace OpTrait diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -212,15 +212,16 @@ 
refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); + auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); + return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference()); + return wrap(unwrap(attr).cast().getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference()); + return wrap(unwrap(attr).cast().getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -704,7 +704,8 @@ // Get the function from the module. The name corresponds to the name of // the kernel function. 
auto kernelName = generateKernelNameConstant( - launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter); + launchOp.getKernelModuleName().getValue(), + launchOp.getKernelName().getValue(), loc, rewriter); auto function = moduleGetFunctionCallBuilder.create( loc, rewriter, {module.getResult(0), kernelName}); auto zero = rewriter.create(loc, llvmInt32Type, diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -106,7 +106,8 @@ Operation *op) const { using LLVM::LLVMFuncOp; - Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcName); + auto funcAttr = StringAttr::get(op->getContext(), funcName); + Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); if (funcOp) return cast(*funcOp); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -181,9 +181,8 @@ StringRef(binary.data(), binary.size()))); // Set entry point name as an attribute. - vulkanLaunchCallOp->setAttr( - kSPIRVEntryPointAttrName, - StringAttr::get(loc->getContext(), launchOp.getKernelName())); + vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName, + launchOp.getKernelName()); launchOp.erase(); } diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -52,9 +52,8 @@ // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 
FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); auto module = op->getParentOfType(); - if (module.lookupSymbol(fnName)) { + if (module.lookupSymbol(fnNameAttr.getAttr())) return fnNameAttr; - } SmallVector inputTypes(extractOperandTypes(op)); assert(op->getNumResults() == 0 && diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -127,8 +127,9 @@ // {spv_module_name}_{function_name} auto entryPoint = *module.getOps().begin(); StringRef funcName = entryPoint.fn(); - auto funcOp = module.lookupSymbol(funcName); - std::string newFuncName = spvModuleName.str() + "_" + funcName.str(); + auto funcOp = module.lookupSymbol(entryPoint.fnAttr()); + StringAttr newFuncName = + StringAttr::get(module->getContext(), spvModuleName + "_" + funcName); if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module))) return failure(); SymbolTable::setSymbolName(funcOp, newFuncName); @@ -166,9 +167,10 @@ // is named: // __spv__{kernel_module_name} // based on GPU to SPIR-V conversion. - StringRef kernelModuleName = launchOp.getKernelModuleName(); + StringRef kernelModuleName = launchOp.getKernelModuleName().getValue(); std::string spvModuleName = kSPIRVModule + kernelModuleName.str(); - auto spvModule = module.lookupSymbol(spvModuleName); + auto spvModule = module.lookupSymbol( + StringAttr::get(context, spvModuleName)); if (!spvModule) { return launchOp.emitOpError("SPIR-V kernel module '") << spvModuleName << "' is not found"; @@ -180,9 +182,10 @@ // variables. The name of the kernel will be // {spv_module_name}_{kernel_function_name} // to avoid symbolic name conflicts. 
- StringRef kernelFuncName = launchOp.getKernelName(); + StringRef kernelFuncName = launchOp.getKernelName().getValue(); std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str(); - auto kernelFunc = module.lookupSymbol(newKernelFuncName); + auto kernelFunc = module.lookupSymbol( + StringAttr::get(context, newKernelFuncName)); if (!kernelFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1523,12 +1523,13 @@ llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName, std::to_string(descriptorSet.getInt()), std::to_string(binding.getInt())); + auto nameAttr = StringAttr::get(op->getContext(), name); // Replace all symbol uses and set the new symbol name. Finally, remove // descriptor set and binding attributes. - if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule))) + if (failed(SymbolTable::replaceAllSymbolUses(op, nameAttr, spvModule))) op.emitError("unable to replace all symbol uses for ") << name; - SymbolTable::setSymbolName(op, name); + SymbolTable::setSymbolName(op, nameAttr); op->removeAttr(kDescriptorSet); op->removeAttr(kBinding); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -196,14 +196,15 @@ return success(); // Check that `launch_func` refers to a well-formed GPU kernel module. 
- StringRef kernelModuleName = launchOp.getKernelModuleName(); + StringAttr kernelModuleName = launchOp.getKernelModuleName(); auto kernelModule = module.lookupSymbol(kernelModuleName); if (!kernelModule) return launchOp.emitOpError() - << "kernel module '" << kernelModuleName << "' is undefined"; + << "kernel module '" << kernelModuleName.getValue() + << "' is undefined"; // Check that `launch_func` refers to a well-formed kernel function. - Operation *kernelFunc = module.lookupSymbol(launchOp.kernel()); + Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr()); auto kernelGPUFunction = dyn_cast_or_null(kernelFunc); auto kernelLLVMFunction = dyn_cast_or_null(kernelFunc); if (!kernelGPUFunction && !kernelLLVMFunction) @@ -555,11 +556,11 @@ return getNumOperands() - asyncDependencies().size() - kNumConfigOperands; } -StringRef LaunchFuncOp::getKernelModuleName() { +StringAttr LaunchFuncOp::getKernelModuleName() { return kernel().getRootReference(); } -StringRef LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); } +StringAttr LaunchFuncOp::getKernelName() { return kernel().getLeafReference(); } Value LaunchFuncOp::getKernelOperand(unsigned i) { return getOperand(asyncDependencies().size() + kNumConfigOperands + i); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -343,8 +343,8 @@ // a constraint in the operation definition. 
for (SymbolRefAttr symbolRef : attribute.cast().getAsRange()) { - StringRef metadataName = symbolRef.getRootReference(); - StringRef symbolName = symbolRef.getLeafReference(); + StringAttr metadataName = symbolRef.getRootReference(); + StringAttr symbolName = symbolRef.getLeafReference(); // We want @metadata::@symbol, not just @symbol if (metadataName == symbolName) { return op->emitOpError() << "expected '" << symbolRef @@ -770,7 +770,7 @@ bool isIndirect = false; // If this is an indirect call, the callee attribute is missing. - Optional calleeName = op.callee(); + FlatSymbolRefAttr calleeName = op.calleeAttr(); if (!calleeName) { isIndirect = true; if (!op.getNumOperands()) @@ -782,14 +782,15 @@ << ptrType; fnType = ptrType.getElementType(); } else { - Operation *callee = SymbolTable::lookupNearestSymbolFrom(op, *calleeName); + Operation *callee = + SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr()); if (!callee) return op.emitOpError() - << "'" << *calleeName + << "'" << calleeName.getValue() << "' does not reference a symbol in the current scope"; auto fn = dyn_cast(callee); if (!fn) - return op.emitOpError() << "'" << *calleeName + return op.emitOpError() << "'" << calleeName.getValue() << "' does not reference a valid LLVM function"; fnType = fn.getType(); @@ -2253,14 +2254,14 @@ if (!accessGroupRef) return op->emitOpError() << "expected '" << attr << "' to be a symbol reference"; - StringRef metadataName = accessGroupRef.getRootReference(); + StringAttr metadataName = accessGroupRef.getRootReference(); auto metadataOp = SymbolTable::lookupNearestSymbolFrom( op->getParentOp(), metadataName); if (!metadataOp) return op->emitOpError() << "expected '" << attr << "' to reference a metadata op"; - StringRef accessGroupName = accessGroupRef.getLeafReference(); + StringAttr accessGroupName = accessGroupRef.getLeafReference(); Operation *accessGroupOp = SymbolTable::lookupNearestSymbolFrom(metadataOp, accessGroupName); if (!accessGroupOp) diff --git 
a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1066,7 +1066,7 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) { auto varOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(), - addressOfOp.variable())); + addressOfOp.variableAttr())); if (!varOp) { return addressOfOp.emitOpError("expected spv.GlobalVariable symbol"); } @@ -1953,14 +1953,14 @@ //===----------------------------------------------------------------------===// static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { - auto fnName = functionCallOp.callee(); + auto fnName = functionCallOp.calleeAttr(); auto funcOp = dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( functionCallOp->getParentOp(), fnName)); if (!funcOp) { return functionCallOp.emitOpError("callee function '") - << fnName << "' not found in nearest symbol table"; + << fnName.getValue() << "' not found in nearest symbol table"; } auto functionType = funcOp.getType(); @@ -2115,7 +2115,7 @@ if (auto init = varOp->getAttrOfType(kInitializerAttrName)) { Operation *initOp = SymbolTable::lookupNearestSymbolFrom( - varOp->getParentOp(), init.getValue()); + varOp->getParentOp(), init.getAttr()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. 
@@ -2691,7 +2691,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( - referenceOfOp->getParentOp(), referenceOfOp.spec_const()); + referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr()); Type constType; auto specConstOp = dyn_cast_or_null(specConstSym); @@ -3516,17 +3516,17 @@ if (cType.isa()) return constOp.emitError("unsupported composite type ") << cType; - else if (constituents.size() != cType.getNumElements()) + if (constituents.size() != cType.getNumElements()) return constOp.emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { - auto constituent = constituents[index].dyn_cast(); + auto constituent = constituents[index].cast(); auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( - constOp->getParentOp(), constituent.getValue())); + constOp->getParentOp(), constituent.getAttr())); if (constituentSpecConstOp.default_value().getType() != cType.getElementType(index)) diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -30,21 +30,20 @@ /// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric /// suffix in `lastUsedID`. 
-static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, - spirv::ModuleOp module) { +static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID, + spirv::ModuleOp module) { SmallString<64> newSymName(oldSymName); newSymName.push_back('_'); - while (lastUsedID < maxFreeID) { - std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); + MLIRContext *ctx = module->getContext(); - if (!SymbolTable::lookupSymbolIn(module, possible)) { - newSymName += llvm::utostr(lastUsedID); - break; - } + while (lastUsedID < maxFreeID) { + auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID)); + if (!SymbolTable::lookupSymbolIn(module, possible)) + return possible; } - return newSymName; + return StringAttr::get(ctx, newSymName); } /// Checks if a symbol with the same name as `op` already exists in `source`. @@ -57,7 +56,7 @@ return success(); StringRef oldSymName = op.getName(); - SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); + StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target); if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) return op.emitError("unable to update all symbol uses for ") @@ -234,7 +233,7 @@ SymbolOpInterface replacementSymOp = result.first->second; if (failed(SymbolTable::replaceAllSymbolUses( - symbolOp, replacementSymOp.getName(), combinedModule))) { + symbolOp, replacementSymOp.getNameAttr(), combinedModule))) { symbolOp.emitError("unable to update all symbol uses for ") << symbolOp.getName() << " to " << replacementSymOp.getName(); return nullptr; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -64,11 +64,11 @@ LogicalResult matchAndRewrite(spirv::AddressOfOp op, 
PatternRewriter &rewriter) const override { auto spirvModule = op->getParentOfType(); - auto varName = op.variable(); + auto varName = op.variableAttr(); auto varOp = spirvModule.lookupSymbol(varName); rewriter.replaceOpWithNewOp( - op, varOp.type(), rewriter.getSymbolRefAttr(varName)); + op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr())); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -96,19 +96,21 @@ } /// Returns function reference (first hit also inserts into module). -static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type result, +static FlatSymbolRefAttr getFunc(Operation *op, StringRef name, Type resultType, ValueRange operands) { MLIRContext *context = op->getContext(); auto module = op->getParentOfType(); - auto func = module.lookupSymbol(name); + auto result = SymbolRefAttr::get(context, name); + auto func = module.lookupSymbol(result.getAttr()); if (!func) { OpBuilder moduleBuilder(module.getBodyRegion()); moduleBuilder - .create(op->getLoc(), name, - FunctionType::get(context, operands.getTypes(), result)) + .create( + op->getLoc(), name, + FunctionType::get(context, operands.getTypes(), resultType)) .setPrivate(); } - return SymbolRefAttr::get(context, name); + return result; } /// Generates a call into the "swiss army knife" method of the sparse runtime diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1659,7 +1659,7 @@ printType(typeAttr.getValue()); } else if (auto refAttr = attr.dyn_cast()) { - printSymbolReference(refAttr.getRootReference(), os); + printSymbolReference(refAttr.getRootReference().getValue(), os); for (FlatSymbolRefAttr nestedRef : 
refAttr.getNestedReferences()) { os << "::"; printSymbolReference(nestedRef.getValue(), os); diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -216,13 +216,15 @@ assert(symName && "value does not have a valid symbol name"); return getSymbolRefAttr(symName.getValue()); } -FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { - return SymbolRefAttr::get(getContext(), value); + +FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) { + return SymbolRefAttr::get(value); } + SymbolRefAttr -Builder::getSymbolRefAttr(StringRef value, +Builder::getSymbolRefAttr(StringAttr value, ArrayRef nestedReferences) { - return SymbolRefAttr::get(getContext(), value, nestedReferences); + return SymbolRefAttr::get(value, nestedReferences); } ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -273,12 +273,16 @@ //===----------------------------------------------------------------------===// FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { - return get(ctx, value, llvm::None).cast(); + return get(StringAttr::get(ctx, value)); } -StringRef SymbolRefAttr::getLeafReference() const { +FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { + return get(value, {}).cast(); +} + +StringAttr SymbolRefAttr::getLeafReference() const { ArrayRef nestedRefs = getNestedReferences(); - return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getValue(); + return nestedRefs.empty() ? 
getRootReference() : nestedRefs.back().getAttr(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -22,17 +22,13 @@ return op->getNumRegions() == 1 && !op->getDialect(); } -/// Returns the string name of the given symbol, or None if this is not a +/// Returns the string name of the given symbol, or null if this is not a /// symbol. -static Optional getNameIfSymbol(Operation *symbol) { - auto nameAttr = - symbol->getAttrOfType(SymbolTable::getSymbolAttrName()); - return nameAttr ? nameAttr.getValue() : Optional(); +static StringAttr getNameIfSymbol(Operation *op) { + return op->getAttrOfType(SymbolTable::getSymbolAttrName()); } -static Optional getNameIfSymbol(Operation *symbol, - Identifier symbolAttrNameId) { - auto nameAttr = symbol->getAttrOfType(symbolAttrNameId); - return nameAttr ? nameAttr.getValue() : Optional(); +static StringAttr getNameIfSymbol(Operation *op, Identifier symbolAttrNameId) { + return op->getAttrOfType(symbolAttrNameId); } /// Computes the nested symbol reference attribute for the symbol 'symbolName' @@ -40,13 +36,13 @@ /// to the given operation 'within', where 'within' is an ancestor of 'symbol'. /// Returns success if all references up to 'within' could be computed. static LogicalResult -collectValidReferencesFor(Operation *symbol, StringRef symbolName, +collectValidReferencesFor(Operation *symbol, StringAttr symbolName, Operation *within, SmallVectorImpl &results) { assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); MLIRContext *ctx = symbol->getContext(); - auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName); + auto leafRef = FlatSymbolRefAttr::get(symbolName); results.push_back(leafRef); // Early exit for when 'within' is the parent of 'symbol'. 
@@ -63,17 +59,16 @@ if (!symbolTableOp->hasTrait()) return failure(); // Each parent of 'symbol' should also be a symbol. - Optional symbolTableName = - getNameIfSymbol(symbolTableOp, symbolNameId); + StringAttr symbolTableName = getNameIfSymbol(symbolTableOp, symbolNameId); if (!symbolTableName) return failure(); - results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs)); + results.push_back(SymbolRefAttr::get(symbolTableName, nestedRefs)); symbolTableOp = symbolTableOp->getParentOp(); if (symbolTableOp == within) break; nestedRefs.insert(nestedRefs.begin(), - FlatSymbolRefAttr::get(ctx, *symbolTableName)); + FlatSymbolRefAttr::get(symbolTableName)); } while (true); return success(); } @@ -119,11 +114,11 @@ Identifier symbolNameId = Identifier::get(SymbolTable::getSymbolAttrName(), symbolTableOp->getContext()); for (auto &op : symbolTableOp->getRegion(0).front()) { - Optional name = getNameIfSymbol(&op, symbolNameId); + StringAttr name = getNameIfSymbol(&op, symbolNameId); if (!name) continue; - auto inserted = symbolTable.insert({*name, &op}); + auto inserted = symbolTable.insert({name, &op}); (void)inserted; assert(inserted.second && "expected region to contain uniquely named symbol operations"); @@ -133,18 +128,21 @@ /// Look up a symbol with the specified name, returning null if no such name /// exists. Names never include the @ on them. Operation *SymbolTable::lookup(StringRef name) const { + return lookup(StringAttr::get(symbolTableOp->getContext(), name)); +} +Operation *SymbolTable::lookup(StringAttr name) const { return symbolTable.lookup(name); } /// Erase the given symbol from the table. 
void SymbolTable::erase(Operation *symbol) { - Optional name = getNameIfSymbol(symbol); + StringAttr name = getNameIfSymbol(symbol); assert(name && "expected valid 'name' attribute"); assert(symbol->getParentOp() == symbolTableOp && "expected this operation to be inside of the operation with this " "SymbolTable"); - auto it = symbolTable.find(*name); + auto it = symbolTable.find(name); if (it != symbolTable.end() && it->second == symbol) { symbolTable.erase(it); symbol->erase(); @@ -180,7 +178,7 @@ // Add this symbol to the symbol table, uniquing the name if a conflict is // detected. - StringRef name = getSymbolName(symbol); + StringAttr name = getSymbolName(symbol); if (symbolTable.insert({name, symbol}).second) return; // If the symbol was already in the table, also return. @@ -188,28 +186,31 @@ return; // If a conflict was detected, then the symbol will not have been added to // the symbol table. Try suffixes until we get to a unique name that works. - SmallString<128> nameBuffer(name); + SmallString<128> nameBuffer(name.getValue()); unsigned originalLength = nameBuffer.size(); + MLIRContext *context = symbol->getContext(); + // Iteratively try suffixes until we find one that isn't used. do { nameBuffer.resize(originalLength); nameBuffer += '_'; nameBuffer += std::to_string(uniquingCounter++); - } while (!symbolTable.insert({nameBuffer, symbol}).second); + } while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol}) + .second); setSymbolName(symbol, nameBuffer); } /// Returns the name of the given symbol operation. -StringRef SymbolTable::getSymbolName(Operation *symbol) { - Optional name = getNameIfSymbol(symbol); +StringAttr SymbolTable::getSymbolName(Operation *symbol) { + StringAttr name = getNameIfSymbol(symbol); assert(name && "expected valid symbol name"); - return *name; + return name; } + /// Sets the name of the given symbol operation. 
-void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { - symbol->setAttr(getSymbolAttrName(), - StringAttr::get(symbol->getContext(), name)); +void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) { + symbol->setAttr(getSymbolAttrName(), name); } /// Returns the visibility of the given symbol operation. @@ -295,7 +296,7 @@ /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol /// was found. Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp, - StringRef symbol) { + StringAttr symbol) { assert(symbolTableOp->hasTrait()); Region ®ion = symbolTableOp->getRegion(0); if (region.empty()) @@ -322,7 +323,7 @@ static LogicalResult lookupSymbolInImpl( Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl &symbols, - function_ref lookupSymbolFn) { + function_ref lookupSymbolFn) { assert(symbolTableOp->hasTrait()); // Lookup the root reference for this symbol. @@ -343,7 +344,7 @@ // Otherwise, lookup each of the nested non-leaf references and ensure that // each corresponds to a valid symbol table. for (FlatSymbolRefAttr ref : nestedRefs.drop_back()) { - symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getValue()); + symbolTableOp = lookupSymbolFn(symbolTableOp, ref.getAttr()); if (!symbolTableOp || !symbolTableOp->hasTrait()) return failure(); symbols.push_back(symbolTableOp); @@ -355,7 +356,7 @@ LogicalResult SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol, SmallVectorImpl &symbols) { - auto lookupFn = [](Operation *symbolTableOp, StringRef symbol) { + auto lookupFn = [](Operation *symbolTableOp, StringAttr symbol) { return lookupSymbolIn(symbolTableOp, symbol); }; return lookupSymbolInImpl(symbolTableOp, symbol, symbols, lookupFn); @@ -365,7 +366,7 @@ /// closes parent operation with the 'OpTrait::SymbolTable' trait. Returns /// nullptr if no valid symbol was found. 
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from, - StringRef symbol) { + StringAttr symbol) { Operation *symbolTableOp = getNearestSymbolTable(from); return symbolTableOp ? lookupSymbolIn(symbolTableOp, symbol) : nullptr; } @@ -610,7 +611,7 @@ /// Collect all of the symbol scopes from 'symbol' to (inclusive) 'limit'. static SmallVector collectSymbolScopes(Operation *symbol, Operation *limit) { - StringRef symName = SymbolTable::getSymbolName(symbol); + StringAttr symName = SymbolTable::getSymbolName(symbol); assert(!symbol->hasTrait() || symbol != limit); // Compute the ancestors of 'limit'. @@ -625,7 +626,7 @@ // doesn't support parent references. if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) - return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}}; + return {{SymbolRefAttr::get(symName), limit}}; return {}; } @@ -679,9 +680,9 @@ return scopes; } template -static SmallVector collectSymbolScopes(StringRef symbol, +static SmallVector collectSymbolScopes(StringAttr symbol, IRUnit *limit) { - return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}}; + return {{SymbolRefAttr::get(symbol), limit}}; } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -753,7 +754,7 @@ /// operation 'from', invoking the provided callback for each. This does not /// traverse into any nested symbol tables. This function returns None if there /// are any unknown operations that may potentially be symbol tables. 
-auto SymbolTable::getSymbolUses(StringRef symbol, Operation *from) +auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from) -> Optional { return getSymbolUsesImpl(symbol, from); } @@ -761,7 +762,7 @@ -> Optional { return getSymbolUsesImpl(symbol, from); } -auto SymbolTable::getSymbolUses(StringRef symbol, Region *from) +auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from) -> Optional { return getSymbolUsesImpl(symbol, from); } @@ -792,13 +793,13 @@ /// the given operation 'from'. This does not traverse into any nested symbol /// tables. This function will also return false if there are any unknown /// operations that may potentially be symbol tables. -bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Operation *from) { +bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) { return symbolKnownUseEmptyImpl(symbol, from); } bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) { return symbolKnownUseEmptyImpl(symbol, from); } -bool SymbolTable::symbolKnownUseEmpty(StringRef symbol, Region *from) { +bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) { return symbolKnownUseEmptyImpl(symbol, from); } bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) { @@ -861,14 +862,13 @@ return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; - return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(), - nestedRefs); + return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs); } /// The implementation of SymbolTable::replaceAllSymbolUses below. template static LogicalResult -replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { +replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) { // A collection of operations along with their new attribute dictionary. 
std::vector> updatedAttrDicts; @@ -888,8 +888,7 @@ }; // Generate a new attribute to replace the given attribute. - MLIRContext *ctx = limit->getContext(); - FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, @@ -905,13 +904,13 @@ if (useRef != scope.symbol) { if (scope.symbol.isa()) { replacementRef = - SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences()); + SymbolRefAttr::get(newSymbol, useRef.getNestedReferences()); } else { auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); nestedRefs[scope.symbol.getNestedReferences().size() - 1] = newLeafAttr; replacementRef = - SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs); + SymbolRefAttr::get(useRef.getRootReference(), nestedRefs); } } @@ -949,23 +948,23 @@ /// 'from'. This does not traverse into any nested symbol tables. If there are /// any unknown operations that may potentially be symbol tables, no uses are /// replaced and failure is returned. 
-LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, +LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, + StringAttr newSymbol, Operation *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, - StringRef newSymbol, + StringAttr newSymbol, Operation *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } -LogicalResult SymbolTable::replaceAllSymbolUses(StringRef oldSymbol, - StringRef newSymbol, +LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol, + StringAttr newSymbol, Region *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol, - StringRef newSymbol, + StringAttr newSymbol, Region *from) { return replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from); } @@ -975,7 +974,7 @@ //===----------------------------------------------------------------------===// Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, - StringRef symbol) { + StringAttr symbol) { return getSymbolTable(symbolTableOp).lookup(symbol); } Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, @@ -992,7 +991,7 @@ SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr name, SmallVectorImpl &symbols) { - auto lookupFn = [this](Operation *symbolTableOp, StringRef symbol) { + auto lookupFn = [this](Operation *symbolTableOp, StringAttr symbol) { return lookupSymbolIn(symbolTableOp, symbol); }; return lookupSymbolInImpl(symbolTableOp, name, symbols, lookupFn); @@ -1003,7 +1002,7 @@ /// 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol was /// found. Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from, - StringRef symbol) { + StringAttr symbol) { Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(from); return symbolTableOp ? 
lookupSymbolIn(symbolTableOp, symbol) : nullptr; } @@ -1052,7 +1051,7 @@ } void SymbolUserMap::replaceAllUsesWith(Operation *symbol, - StringRef newSymbolName) { + StringAttr newSymbolName) { auto it = symbolToUsers.find(symbol); if (it == symbolToUsers.end()) return; diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -818,7 +818,7 @@ void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) { ByteCodeField patternIndex = patterns.size(); patterns.emplace_back(PDLByteCodePattern::create( - op, rewriterToAddr[op.rewriter().getLeafReference()])); + op, rewriterToAddr[op.rewriter().getLeafReference().getValue()])); writer.append(OpCode::RecordMatch, patternIndex, SuccessorRange(op.getOperation()), op.matchedOps()); writer.appendPDLValueList(op.inputs()); diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -814,8 +814,8 @@ llvm::MDNode * ModuleTranslation::getAliasScope(Operation &opInst, SymbolRefAttr aliasScopeRef) const { - StringRef metadataName = aliasScopeRef.getRootReference(); - StringRef scopeName = aliasScopeRef.getLeafReference(); + StringAttr metadataName = aliasScopeRef.getRootReference(); + StringAttr scopeName = aliasScopeRef.getLeafReference(); auto metadataOp = SymbolTable::lookupNearestSymbolFrom( opInst.getParentOp(), metadataName); Operation *aliasScopeOp = diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -84,7 +84,7 @@ table.erase(op); assert(!table.lookup(name) && "expected erased operation to be unknown now"); - module.emitRemark() << name << " function successfully erased"; + module.emitRemark() << name.getValue() << " function successfully erased"; 
} } }; @@ -110,8 +110,8 @@ StringAttr newName = nestedOp->getAttrOfType("sym.new_name"); if (!newName) return; - symbolUsers.replaceAllUsesWith(nestedOp, newName.getValue()); - SymbolTable::setSymbolName(nestedOp, newName.getValue()); + symbolUsers.replaceAllUsesWith(nestedOp, newName); + SymbolTable::setSymbolName(nestedOp, newName); }); } }; diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp --- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp +++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp @@ -80,8 +80,10 @@ // The test cases are encompassed via two modules, one containing the // patterns and one containing the operations to rewrite. - ModuleOp patternModule = module.lookupSymbol("patterns"); - ModuleOp irModule = module.lookupSymbol("ir"); + ModuleOp patternModule = module.lookupSymbol( + StringAttr::get(module->getContext(), "patterns")); + ModuleOp irModule = module.lookupSymbol( + StringAttr::get(module->getContext(), "ir")); if (!patternModule || !irModule) return;