diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -2545,7 +2545,7 @@ [{ $_state.addOperands(operands); $_state.addAttribute(calleeAttrName($_state.name), - $_builder.getSymbolRefAttr(callee)); + SymbolRefAttr::get(callee)); $_state.addTypes(callee.getType().getResults()); }]>, OpBuilder<(ins "mlir::SymbolRefAttr":$callee, @@ -2560,7 +2560,8 @@ "llvm::ArrayRef":$results, CArg<"mlir::ValueRange", "{}">:$operands), [{ - build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results, + build($_builder, $_state, + SymbolRefAttr::get($_builder.getContext(), callee), results, operands); }]>]; diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -919,7 +919,7 @@ funcOp = getWrapper(rtCallGenerator, name, signature, loadRefArguments); } - return builder.getSymbolRefAttr(funcOp.getName()); + return SymbolRefAttr::get(funcOp); } //===----------------------------------------------------------------------===// diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -398,8 +398,7 @@ //===----------------------------------------------------------------------===// void fir::ConvertOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { -} + OwningRewritePatternList &results, MLIRContext *context) {} mlir::OpFoldResult fir::ConvertOp::fold(llvm::ArrayRef opnds) { if (value().getType() == getType()) @@ -629,7 +628,8 @@ result.addAttribute(typeAttrName(result.name), mlir::TypeAttr::get(type)); result.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); - result.addAttribute(symbolAttrName(), 
builder.getSymbolRefAttr(name)); + result.addAttribute(symbolAttrName(), + SymbolRefAttr::get(builder.getContext(), name)); if (isConstant) result.addAttribute(constantAttrName(result.name), builder.getUnitAttr()); if (initialVal) @@ -1330,7 +1330,7 @@ template static A getSubOperands(unsigned pos, A allArgs, mlir::DenseIntElementsAttr ranges, - AdditionalArgs &&... additionalArgs) { + AdditionalArgs &&...additionalArgs) { unsigned start = 0; for (unsigned i = 0; i < pos; ++i) start += (*(ranges.begin() + i)).getZExtValue(); diff --git a/mlir/examples/toy/Ch2/mlir/Dialect.cpp b/mlir/examples/toy/Ch2/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch2/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch2/mlir/Dialect.cpp @@ -174,7 +174,8 @@ // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch3/mlir/Dialect.cpp b/mlir/examples/toy/Ch3/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch3/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch3/mlir/Dialect.cpp @@ -174,7 +174,8 @@ // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } //===----------------------------------------------------------------------===// diff --git a/mlir/examples/toy/Ch4/mlir/Dialect.cpp b/mlir/examples/toy/Ch4/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch4/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch4/mlir/Dialect.cpp @@ -256,7 +256,8 @@ // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } /// Return the callee of the generic call operation, this is required by the diff --git a/mlir/examples/toy/Ch5/mlir/Dialect.cpp b/mlir/examples/toy/Ch5/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch5/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch5/mlir/Dialect.cpp @@ -256,7 +256,8 @@ // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } /// Return the callee of the generic call operation, this is required by the diff --git a/mlir/examples/toy/Ch6/mlir/Dialect.cpp b/mlir/examples/toy/Ch6/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch6/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch6/mlir/Dialect.cpp @@ -256,7 +256,8 @@ // Generic call always returns an unranked Tensor initially. state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } /// Return the callee of the generic call operation, this is required by the diff --git a/mlir/examples/toy/Ch7/mlir/Dialect.cpp b/mlir/examples/toy/Ch7/mlir/Dialect.cpp --- a/mlir/examples/toy/Ch7/mlir/Dialect.cpp +++ b/mlir/examples/toy/Ch7/mlir/Dialect.cpp @@ -282,7 +282,8 @@ // Generic call always returns an unranked Tensor initially. 
state.addTypes(UnrankedTensorType::get(builder.getF64Type())); state.addOperands(arguments); - state.addAttribute("callee", builder.getSymbolRefAttr(callee)); + state.addAttribute("callee", + mlir::SymbolRefAttr::get(builder.getContext(), callee)); } /// Return the callee of the generic call operation, this is required by the diff --git a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp --- a/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp +++ b/mlir/examples/toy/Ch7/mlir/MLIRGen.cpp @@ -522,7 +522,7 @@ mlir::FuncOp calledFunc = calledFuncIt->second; return builder.create( location, calledFunc.getType().getResult(0), - builder.getSymbolRefAttr(callee), operands); + mlir::SymbolRefAttr::get(builder.getContext(), callee), operands); } /// Emit a print expression. It emits specific operations for two builtins: diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -515,14 +515,22 @@ let results = (outs Variadic); let builders = [ OpBuilder<(ins "LLVMFuncOp":$func, "ValueRange":$operands, - CArg<"ArrayRef", "{}">:$attributes), - [{ + CArg<"ArrayRef", "{}">:$attributes), [{ Type resultType = func.getType().getReturnType(); if (!resultType.isa()) $_state.addTypes(resultType); - $_state.addAttribute("callee", $_builder.getSymbolRefAttr(func)); + $_state.addAttribute("callee", SymbolRefAttr::get(func)); $_state.addAttributes(attributes); $_state.addOperands(operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, results, SymbolRefAttr::get(callee), operands); + }]>, + OpBuilder<(ins "TypeRange":$results, "StringRef":$callee, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, results, + StringAttr::get($_builder.getContext(), callee), operands); }]>]; let verifier = [{ return ::verify(*this); 
}]; let parser = [{ return parseCallOp(parser, result); }]; diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -560,7 +560,7 @@ let builders = [ OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); - $_state.addAttribute("callee",$_builder.getSymbolRefAttr(callee)); + $_state.addAttribute("callee", SymbolRefAttr::get(callee)); $_state.addTypes(callee.getType().getResults()); }]>, OpBuilder<(ins "SymbolRefAttr":$callee, "TypeRange":$results, @@ -569,14 +569,19 @@ $_state.addAttribute("callee", callee); $_state.addTypes(results); }]>, + OpBuilder<(ins "StringAttr":$callee, "TypeRange":$results, + CArg<"ValueRange", "{}">:$operands), [{ + build($_builder, $_state, SymbolRefAttr::get(callee), results, operands); + }]>, OpBuilder<(ins "StringRef":$callee, "TypeRange":$results, CArg<"ValueRange", "{}">:$operands), [{ - build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results, - operands); + build($_builder, $_state, StringAttr::get($_builder.getContext(), callee), + results, operands); }]>]; let extraClassDeclaration = [{ StringRef getCallee() { return callee(); } + StringAttr getCalleeAttr() { return calleeAttr().getAttr(); } FunctionType getCalleeType(); /// Get the argument operands to the called function. 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -97,17 +97,6 @@ FloatAttr getFloatAttr(Type type, const APFloat &value); StringAttr getStringAttr(const Twine &bytes); ArrayAttr getArrayAttr(ArrayRef value); - FlatSymbolRefAttr getSymbolRefAttr(Operation *value); - FlatSymbolRefAttr getSymbolRefAttr(StringAttr value); - SymbolRefAttr getSymbolRefAttr(StringAttr value, - ArrayRef nestedReferences); - SymbolRefAttr getSymbolRefAttr(StringRef value, - ArrayRef nestedReferences) { - return getSymbolRefAttr(getStringAttr(value), nestedReferences); - } - FlatSymbolRefAttr getSymbolRefAttr(StringRef value) { - return getSymbolRefAttr(getStringAttr(value)); - } // Returns a 0-valued attribute of the given `type`. This function only // supports boolean, integer, and 16-/32-/64-bit float types, and vector or diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -23,6 +23,7 @@ class IntegerSet; class IntegerType; class Location; +class Operation; class ShapedType; //===----------------------------------------------------------------------===// @@ -685,12 +686,17 @@ using ValueType = StringRef; /// Construct a symbol reference for the given value name. + static FlatSymbolRefAttr get(StringAttr value) { + return SymbolRefAttr::get(value); + } static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { return SymbolRefAttr::get(ctx, value); } - static FlatSymbolRefAttr get(StringAttr value) { - return SymbolRefAttr::get(value); + /// Convenience getter for building a SymbolRefAttr based on an operation + /// that implements the SymbolTrait. + static FlatSymbolRefAttr get(Operation *symbol) { + return SymbolRefAttr::get(symbol); } /// Returns the name of the held symbol reference as a StringAttr. 
diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -893,8 +893,16 @@ }]>, ]; let extraClassDeclaration = [{ - static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); + static SymbolRefAttr get(MLIRContext *ctx, StringRef value, + ArrayRef nestedRefs); + /// Convenience getters for building a SymbolRefAttr with no path, which is + /// known to produce a FlatSymbolRefAttr. static FlatSymbolRefAttr get(StringAttr value); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); + + /// Convenience getter for building a SymbolRefAttr based on an operation + /// that implements the SymbolTrait. + static FlatSymbolRefAttr get(Operation *symbol); /// Returns the name of the fully resolved symbol, i.e. the leaf of the /// reference path. diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1582,15 +1582,16 @@ let storageType = [{ ::mlir::SymbolRefAttr }]; let returnType = [{ ::mlir::SymbolRefAttr }]; let valueType = NoneType; - let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let constBuilderCall = "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; let convertFromStorage = "$_self"; } + def FlatSymbolRefAttr : Attr()">, "flat symbol reference attribute"> { let storageType = [{ ::mlir::FlatSymbolRefAttr }]; let returnType = [{ ::llvm::StringRef }]; let valueType = NoneType; - let constBuilderCall = "$_builder.getSymbolRefAttr($0)"; + let constBuilderCall = "::mlir::SymbolRefAttr::get($_builder.getContext(), $0)"; let convertFromStorage = "$_self.getValue()"; } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -367,7 +367,7 @@ // Allocate 
memory for the coroutine frame. auto coroAlloc = rewriter.create( - loc, i8Ptr, rewriter.getSymbolRefAttr(kMalloc), + loc, i8Ptr, SymbolRefAttr::get(rewriter.getContext(), kMalloc), ValueRange(coroSize.getResult())); // Begin a coroutine: @llvm.coro.begin. @@ -399,9 +399,9 @@ auto coroMem = rewriter.create(loc, i8Ptr, operands); // Free the memory. - rewriter.replaceOpWithNewOp(op, TypeRange(), - rewriter.getSymbolRefAttr(kFree), - ValueRange(coroMem.getResult())); + rewriter.replaceOpWithNewOp( + op, TypeRange(), SymbolRefAttr::get(rewriter.getContext(), kFree), + ValueRange(coroMem.getResult())); return success(); } diff --git a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h --- a/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h +++ b/mlir/lib/Conversion/GPUCommon/OpToFuncCallLowering.h @@ -62,8 +62,7 @@ LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); auto callOp = rewriter.create( - op->getLoc(), resultType, rewriter.getSymbolRefAttr(funcOp), - castedOperands); + op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands); if (resultType == operands.front().getType()) { rewriter.replaceOp(op, {callOp.getResult(0)}); diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -171,14 +171,13 @@ // Create vulkan launch call op. auto vulkanLaunchCallOp = builder.create( - loc, TypeRange{}, builder.getSymbolRefAttr(kVulkanLaunch), + loc, TypeRange{}, SymbolRefAttr::get(builder.getContext(), kVulkanLaunch), vulkanLaunchOperands); // Set SPIR-V binary shader data as an attribute. 
vulkanLaunchCallOp->setAttr( kSPIRVBlobAttrName, - StringAttr::get(loc->getContext(), - StringRef(binary.data(), binary.size()))); + builder.getStringAttr(StringRef(binary.data(), binary.size()))); // Set entry point name as an attribute. vulkanLaunchCallOp->setAttr(kSPIRVEntryPointAttrName, diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp --- a/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertLaunchFuncToVulkanCalls.cpp @@ -248,9 +248,7 @@ } // Create call to `bindMemRef`. builder.create( - loc, TypeRange(), - builder.getSymbolRefAttr( - StringRef(symbolName.data(), symbolName.size())), + loc, TypeRange(), StringRef(symbolName.data(), symbolName.size()), ValueRange{vulkanRuntime, descriptorSet, descriptorBinding, ptrToMemRefDescriptor}); } @@ -373,8 +371,7 @@ Location loc = cInterfaceVulkanLaunchCallOp.getLoc(); // Create call to `initVulkan`. auto initVulkanCall = builder.create( - loc, TypeRange{getPointerType()}, builder.getSymbolRefAttr(kInitVulkan), - ValueRange{}); + loc, TypeRange{getPointerType()}, kInitVulkan); // The result of `initVulkan` function is a pointer to Vulkan runtime, we // need to pass that pointer to each Vulkan runtime call. auto vulkanRuntime = initVulkanCall.getResult(0); @@ -396,32 +393,29 @@ // Create call to `setBinaryShader` runtime function with the given pointer to // SPIR-V binary and binary size. builder.create( - loc, TypeRange(), builder.getSymbolRefAttr(kSetBinaryShader), + loc, TypeRange(), kSetBinaryShader, ValueRange{vulkanRuntime, ptrToSPIRVBinary, binarySize}); // Create LLVM global with entry point name. Value entryPointName = createEntryPointNameConstant( spirvAttributes.second.getValue(), loc, builder); // Create call to `setEntryPoint` runtime function with the given pointer to // entry point name. 
- builder.create(loc, TypeRange(), - builder.getSymbolRefAttr(kSetEntryPoint), + builder.create(loc, TypeRange(), kSetEntryPoint, ValueRange{vulkanRuntime, entryPointName}); // Create number of local workgroup for each dimension. builder.create( - loc, TypeRange(), builder.getSymbolRefAttr(kSetNumWorkGroups), + loc, TypeRange(), kSetNumWorkGroups, ValueRange{vulkanRuntime, cInterfaceVulkanLaunchCallOp.getOperand(0), cInterfaceVulkanLaunchCallOp.getOperand(1), cInterfaceVulkanLaunchCallOp.getOperand(2)}); // Create call to `runOnVulkan` runtime function. - builder.create(loc, TypeRange(), - builder.getSymbolRefAttr(kRunOnVulkan), + builder.create(loc, TypeRange(), kRunOnVulkan, ValueRange{vulkanRuntime}); // Create call to 'deinitVulkan' runtime function. - builder.create(loc, TypeRange(), - builder.getSymbolRefAttr(kDeinitVulkan), + builder.create(loc, TypeRange(), kDeinitVulkan, ValueRange{vulkanRuntime}); // Declare runtime functions. diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -50,7 +50,8 @@ } // fnName is a dynamic std::string, unique it via a SymbolRefAttr. 
- FlatSymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); + FlatSymbolRefAttr fnNameAttr = + SymbolRefAttr::get(rewriter.getContext(), fnName); auto module = op->getParentOfType(); if (module.lookupSymbol(fnNameAttr.getAttr())) return fnNameAttr; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -305,7 +305,7 @@ op.getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op.getLoc())); rewriter.replaceOpWithNewOp( - op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted); + op, TypeRange(), SymbolRefAttr::get(freeFunc), casted); return success(); } }; diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp --- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp @@ -559,9 +559,10 @@ /*results=*/llvm::None)); builder.create(rewriter.getLoc()); - return builder.getSymbolRefAttr( + return SymbolRefAttr::get( + builder.getContext(), pdl_interp::PDLInterpDialect::getRewriterModuleName(), - builder.getSymbolRefAttr(rewriterFunc)); + SymbolRefAttr::get(rewriterFunc)); } void PatternLowering::generateRewriter( diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1194,8 +1194,8 @@ // Helper to emit a call. 
static void emitCall(ConversionPatternRewriter &rewriter, Location loc, Operation *ref, ValueRange params = ValueRange()) { - rewriter.create(loc, TypeRange(), - rewriter.getSymbolRefAttr(ref), params); + rewriter.create(loc, TypeRange(), SymbolRefAttr::get(ref), + params); } }; diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -542,8 +542,9 @@ blockSize.y, blockSize.z}); result.addOperands(kernelOperands); auto kernelModule = kernelFunc->getParentOfType(); - auto kernelSymbol = builder.getSymbolRefAttr( - kernelModule.getName(), {builder.getSymbolRefAttr(kernelFunc.getName())}); + auto kernelSymbol = + SymbolRefAttr::get(kernelModule.getNameAttr(), + {SymbolRefAttr::get(kernelFunc.getNameAttr())}); result.addAttribute(getKernelAttrName(), kernelSymbol); SmallVector segmentSizes(8, 1); segmentSizes.front() = 0; // Initially no async dependencies. diff --git a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/FunctionCallUtils.cpp @@ -129,7 +129,7 @@ ValueRange paramTypes, ArrayRef resultTypes) { return b - .create(loc, resultTypes, b.getSymbolRefAttr(fn), + .create(loc, resultTypes, SymbolRefAttr::get(fn), paramTypes) ->getResults(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1060,7 +1060,7 @@ void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, spirv::GlobalVariableOp var) { - build(builder, state, var.type(), builder.getSymbolRefAttr(var)); + build(builder, state, var.type(), SymbolRefAttr::get(var)); } static LogicalResult verify(spirv::AddressOfOp addressOfOp) { @@ -1712,8 +1712,7 @@ ArrayRef interfaceVars) { build(builder, 
state, spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), - builder.getSymbolRefAttr(function), - builder.getArrayAttr(interfaceVars)); + SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); } static ParseResult parseEntryPointOp(OpAsmParser &parser, @@ -1772,7 +1771,7 @@ spirv::FuncOp function, spirv::ExecutionMode executionMode, ArrayRef params) { - build(builder, state, builder.getSymbolRefAttr(function), + build(builder, state, SymbolRefAttr::get(function), spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), builder.getI32ArrayAttr(params)); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -68,7 +68,7 @@ auto varOp = spirvModule.lookupSymbol(varName); rewriter.replaceOpWithNewOp( - op, varOp.type(), rewriter.getSymbolRefAttr(varName.getAttr())); + op, varOp.type(), SymbolRefAttr::get(varName.getAttr())); return success(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp @@ -156,7 +156,7 @@ resultMapping.push_back(i); } - CallOp newCallOp = rewriter.create(op.getLoc(), op.getCallee(), + CallOp newCallOp = rewriter.create(op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands); // Build a replacement value for each result to replace its uses. 
If a diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -210,23 +210,6 @@ return ArrayAttr::get(context, value); } -FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { - auto symName = - value->getAttrOfType(SymbolTable::getSymbolAttrName()); - assert(symName && "value does not have a valid symbol name"); - return getSymbolRefAttr(symName.getValue()); -} - -FlatSymbolRefAttr Builder::getSymbolRefAttr(StringAttr value) { - return SymbolRefAttr::get(value); -} - -SymbolRefAttr -Builder::getSymbolRefAttr(StringAttr value, - ArrayRef nestedReferences) { - return SymbolRefAttr::get(value, nestedReferences); -} - ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { auto attrs = llvm::to_vector<8>(llvm::map_range( values, [this](bool v) -> Attribute { return getBoolAttr(v); })); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -10,14 +10,14 @@ #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/DecodeAttributesInterfaces.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/Twine.h" #include "llvm/Support/Endian.h" using namespace mlir; @@ -272,14 +272,26 @@ // SymbolRefAttr //===----------------------------------------------------------------------===// +SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, + ArrayRef nestedRefs) { + return get(StringAttr::get(ctx, value), nestedRefs); +} + FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { - return get(StringAttr::get(ctx, value)); + return get(ctx, value, {}).cast(); } FlatSymbolRefAttr 
SymbolRefAttr::get(StringAttr value) { return get(value, {}).cast(); } +FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { + auto symName = + symbol->getAttrOfType(SymbolTable::getSymbolAttrName()); + assert(symName && "value does not have a valid symbol name"); + return SymbolRefAttr::get(symName); +} + StringAttr SymbolRefAttr::getLeafReference() const { ArrayRef nestedRefs = getNestedReferences(); return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -191,7 +191,8 @@ consumeToken(Token::at_identifier); nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); } - SymbolRefAttr symbolRefAttr = builder.getSymbolRefAttr(nameStr, nestedRefs); + SymbolRefAttr symbolRefAttr = + SymbolRefAttr::get(getContext(), nameStr, nestedRefs); // If we are populating the assembly state, record this symbol reference. if (state.asmState) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1406,9 +1406,8 @@ // If we are populating the assembly parser state, record this as a symbol // reference. 
if (parser.getState().asmState) { - parser.getState().asmState->addUses( - getBuilder().getSymbolRefAttr(result.getValue()), - atToken.getLocRange()); + parser.getState().asmState->addUses(SymbolRefAttr::get(result), + atToken.getLocRange()); } return success(); } diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -245,7 +245,7 @@ return b.getFloatAttr(FloatType::getF32(context), c->getValueAPF()); } if (auto *f = dyn_cast(value)) - return b.getSymbolRefAttr(f->getName()); + return SymbolRefAttr::get(b.getContext(), f->getName()); // Convert constant data to a dense elements attribute. if (auto *cd = dyn_cast(value)) { @@ -668,8 +668,8 @@ } Operation *op; if (llvm::Function *callee = ci->getCalledFunction()) { - op = b.create(loc, tys, b.getSymbolRefAttr(callee->getName()), - ops); + op = b.create( + loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops); } else { Value calledValue = processValue(ci->getCalledOperand()); if (!calledValue) @@ -713,9 +713,10 @@ Operation *op; if (llvm::Function *callee = ii->getCalledFunction()) { - op = b.create(loc, tys, b.getSymbolRefAttr(callee->getName()), - ops, blocks[ii->getNormalDest()], normalArgs, - blocks[ii->getUnwindDest()], unwindArgs); + op = b.create( + loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops, + blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()], + unwindArgs); } else { ops.insert(ops.begin(), processValue(ii->getCalledOperand())); op = b.create(loc, tys, ops, blocks[ii->getNormalDest()], @@ -771,7 +772,7 @@ // If it directly has a name, we can use it. if (pf->hasName()) - return b.getSymbolRefAttr(pf->getName()); + return SymbolRefAttr::get(b.getContext(), pf->getName()); // If it doesn't have a name, currently, only function pointers that are // bitcast to i8* are parsed. 
@@ -779,7 +780,7 @@ if (ce->getOpcode() == llvm::Instruction::BitCast && ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) { if (auto func = dyn_cast(ce->getOperand(0))) - return b.getSymbolRefAttr(func->getName()); + return SymbolRefAttr::get(b.getContext(), func->getName()); } } return FlatSymbolRefAttr(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp @@ -44,20 +44,19 @@ } if (auto varOp = getGlobalVariable(id)) { auto addressOfOp = opBuilder.create( - unknownLoc, varOp.type(), - opBuilder.getSymbolRefAttr(varOp.getOperation())); + unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); return addressOfOp.pointer(); } if (auto constOp = getSpecConstant(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constOp.default_value().getType(), - opBuilder.getSymbolRefAttr(constOp.getOperation())); + SymbolRefAttr::get(constOp.getOperation())); return referenceOfOp.reference(); } if (auto constCompositeOp = getSpecConstantComposite(id)) { auto referenceOfOp = opBuilder.create( unknownLoc, constCompositeOp.type(), - opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); + SymbolRefAttr::get(constCompositeOp.getOperation())); return referenceOfOp.reference(); } if (auto specConstOperationInfo = getSpecConstantOperation(id)) { @@ -357,12 +356,12 @@ return emitError(unknownLoc, "undefined result ") << words[wordIndex] << " while decoding OpEntryPoint"; } - interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); + interface.push_back(SymbolRefAttr::get(arg.getOperation())); wordIndex++; } - opBuilder.create(unknownLoc, execModel, - opBuilder.getSymbolRefAttr(fnName), - opBuilder.getArrayAttr(interface)); + opBuilder.create( + unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), + 
opBuilder.getArrayAttr(interface)); return success(); } @@ -394,7 +393,8 @@ } auto values = opBuilder.getArrayAttr(attrListElems); opBuilder.create( - unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); + unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), + execMode, values); return success(); } @@ -461,8 +461,8 @@ } auto opFunctionCall = opBuilder.create( - unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), - arguments); + unknownLoc, resultType, + SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); if (resultType) valueMap[resultID] = opFunctionCall.getResult(0); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -575,7 +575,7 @@ << operands[wordIndex] << "used as initializer"; } wordIndex++; - initializer = opBuilder.getSymbolRefAttr(initializerOp.getOperation()); + initializer = SymbolRefAttr::get(initializerOp.getOperation()); } if (wordIndex != operands.size()) { return emitError(unknownLoc, @@ -1279,7 +1279,7 @@ elements.reserve(operands.size() - 2); for (unsigned i = 2, e = operands.size(); i < e; ++i) { auto elementInfo = getSpecConstant(operands[i]); - elements.push_back(opBuilder.getSymbolRefAttr(elementInfo)); + elements.push_back(SymbolRefAttr::get(elementInfo)); } auto op = opBuilder.create( diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp --- a/mlir/lib/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp @@ -129,10 +129,10 @@ // Functions called by this function. funcOp.walk([&](CallOp callOp) { - StringRef callee = callOp.getCallee(); + StringAttr callee = callOp.getCalleeAttr(); for (FuncOp &funcOp : normalizableFuncs) { // We compare FuncOp and callee's name. 
- if (callee == funcOp.getName()) { + if (callee == funcOp.getNameAttr()) { setCalleesAndCallersNonNormalizable(funcOp, moduleOp, normalizableFuncs); break; @@ -255,10 +255,9 @@ auto callOp = dyn_cast(userOp); if (!callOp) continue; - StringRef callee = callOp.getCallee(); - Operation *newCallOp = builder.create( - userOp->getLoc(), resultTypes, builder.getSymbolRefAttr(callee), - userOp->getOperands()); + Operation *newCallOp = + builder.create(userOp->getLoc(), callOp.getCalleeAttr(), + resultTypes, userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; for (unsigned resIndex : llvm::seq(0, userOp->getNumResults())) {