diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -2222,7 +2222,7 @@ NamedAttribute targetOffsetAttr = *owner->getAttrDictionary().getNamed(offsetAttr); return getSubOperands( - pos, operands, targetOffsetAttr.second.cast(), + pos, operands, targetOffsetAttr.getValue().cast(), mlir::MutableOperandRange::OperandSegment(pos, targetOffsetAttr)); } diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -72,9 +72,8 @@ [{ if (resultType) $_state.addTypes(resultType); $_state.addOperands(operands); - for (auto namedAttr : attributes) { - $_state.addAttribute(namedAttr.first, namedAttr.second); - } + for (auto namedAttr : attributes) + $_state.addAttribute(namedAttr.getName(), namedAttr.getValue()); }]>; def LLVM_ZeroResultOpBuilder : @@ -82,9 +81,8 @@ CArg<"ArrayRef", "{}">:$attributes), [{ $_state.addOperands(operands); - for (auto namedAttr : attributes) { - $_state.addAttribute(namedAttr.first, namedAttr.second); - } + for (auto namedAttr : attributes) + $_state.addAttribute(namedAttr.getName(), namedAttr.getValue()); }]>; // Compatibility builder that takes an instance of wrapped llvm::VoidType diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -136,13 +136,60 @@ // NamedAttribute //===----------------------------------------------------------------------===// -/// NamedAttribute is combination of a name, represented by a StringAttr, and a -/// value, represented by an Attribute. The attribute pointer should always be -/// non-null. -using NamedAttribute = std::pair; +/// NamedAttribute represents a combination of a name and an Attribute value. 
+class NamedAttribute { +public: + NamedAttribute(StringAttr name, Attribute value); + + /// Return the name of the attribute. + StringAttr getName() const; + + /// Return the dialect of the name of this attribute, if the name is prefixed + /// by a dialect namespace. For example, `llvm.fast_math` would return the + /// LLVM dialect (if it is loaded). Returns nullptr if the dialect isn't + /// loaded, or if the name is not prefixed by a dialect namespace. + Dialect *getNameDialect() const; + + /// Return the value of the attribute. + Attribute getValue() const { return value; } + + /// Set the name of this attribute. `newName` must be a valid, non-empty name. + void setName(StringAttr newName); + + /// Set the value of this attribute. `newValue` must be non-null. + void setValue(Attribute newValue) { + assert(newValue && "expected valid attribute value"); + value = newValue; + } + + /// Compare this attribute to the provided attribute, ordering by name. + bool operator<(const NamedAttribute &rhs) const; + /// Compare this attribute to the provided string, ordering by name. + bool operator<(StringRef rhs) const; + + bool operator==(const NamedAttribute &rhs) const { + return name == rhs.name && value == rhs.value; + } + bool operator!=(const NamedAttribute &rhs) const { return !(*this == rhs); } + +private: + NamedAttribute(Attribute name, Attribute value) : name(name), value(value) {} -bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); -bool operator<(const NamedAttribute &lhs, StringRef rhs); + /// Allow access to internals to enable hashing. + friend ::llvm::hash_code hash_value(const NamedAttribute &arg); + friend DenseMapInfo; + + /// The name of the attribute. This is represented as a StringAttr, but + /// type-erased to Attribute in the field. + Attribute name; + /// The value of the attribute.
+ Attribute value; +}; + +inline ::llvm::hash_code hash_value(const NamedAttribute &arg) { + using AttrPairT = std::pair; + return DenseMapInfo::getHashValue(AttrPairT(arg.name, arg.value)); +} //===----------------------------------------------------------------------===// // AttributeTraitBase @@ -227,6 +274,23 @@ mlir::AttributeStorage *>::NumLowBitsAvailable; }; +template <> struct DenseMapInfo { + static mlir::NamedAttribute getEmptyKey() { + auto emptyAttr = llvm::DenseMapInfo::getEmptyKey(); + return mlir::NamedAttribute(emptyAttr, emptyAttr); + } + static mlir::NamedAttribute getTombstoneKey() { + auto tombAttr = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::NamedAttribute(tombAttr, tombAttr); + } + static unsigned getHashValue(mlir::NamedAttribute val) { + return mlir::hash_value(val); + } + static bool isEqual(mlir::NamedAttribute lhs, mlir::NamedAttribute rhs) { + return lhs == rhs; + } +}; + } // namespace llvm #endif diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -609,10 +609,10 @@ // that they contain a dialect prefix in their name. Call the dialect, if // registered, to verify the attributes themselves. for (auto attr : argAttrs) { - if (!attr.first.strref().contains('.')) + if (!attr.getName().strref().contains('.')) return funcOp.emitOpError( "arguments may only have dialect attributes"); - if (Dialect *dialect = attr.first.getReferencedDialect()) { + if (Dialect *dialect = attr.getNameDialect()) { if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0, /*argIndex=*/i, attr))) return failure(); @@ -643,9 +643,9 @@ // that they contain a dialect prefix in their name. Call the dialect, if // registered, to verify the attributes themselves.
for (auto attr : resultAttrs) { - if (!attr.first.strref().contains('.')) + if (!attr.getName().strref().contains('.')) return funcOp.emitOpError("results may only have dialect attributes"); - if (Dialect *dialect = attr.first.getReferencedDialect()) { + if (Dialect *dialect = attr.getNameDialect()) { if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0, /*resultIndex=*/i, attr))) diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -373,7 +373,7 @@ bool (*)(NamedAttribute)> { static bool filter(NamedAttribute attr) { // Dialect attributes are prefixed by the dialect name, like operations. - return attr.first.strref().count('.'); + return attr.getName().strref().count('.'); } explicit dialect_attr_iterator(ArrayRef::iterator it, @@ -407,7 +407,7 @@ NamedAttrList attrs; attrs.append(std::begin(dialectAttrs), std::end(dialectAttrs)); for (auto attr : getAttrs()) - if (!attr.first.strref().contains('.')) + if (!attr.getName().strref().contains('.')) attrs.push_back(attr); setAttrs(attrs.getDictionary(getContext())); } diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -382,7 +382,7 @@ std::pair findAttrUnsorted(IteratorT first, IteratorT last, NameT name) { for (auto it = first; it != last; ++it) - if (it->first == name) + if (it->getName() == name) return {it, true}; return {last, false}; } @@ -399,7 +399,7 @@ while (length > 0) { ptrdiff_t half = length / 2; IteratorT mid = first + half; - int compare = mid->first.strref().compare(name); + int compare = mid->getName().strref().compare(name); if (compare < 0) { first = mid + 1; length = length - half - 1; diff --git a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h --- 
a/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h +++ b/mlir/include/mlir/Target/LLVMIR/LLVMTranslationInterface.h @@ -81,7 +81,7 @@ amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const { if (const LLVMTranslationDialectInterface *iface = - getInterfaceFor(attribute.first.getReferencedDialect())) { + getInterfaceFor(attribute.getNameDialect())) { return iface->amendOperation(op, attribute, moduleTranslation); } return success(); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -83,7 +83,7 @@ intptr_t pos) { NamedAttribute attribute = unwrap(attr).cast().getValue()[pos]; - return {wrap(attribute.first), wrap(attribute.second)}; + return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -432,7 +432,7 @@ MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { NamedAttribute attr = unwrap(op)->getAttrs()[pos]; - return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)}; + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; } MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp @@ -55,9 +55,9 @@ // not specific to function modeling. 
SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { - if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == function_like_impl::getTypeAttrName() || - attr.first == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == function_like_impl::getTypeAttrName() || + attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -216,10 +216,10 @@ rewriter.getFunctionType(signatureConverter.getConvertedTypes(), llvm::None)); for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first == function_like_impl::getTypeAttrName() || - namedAttr.first == SymbolTable::getSymbolAttrName()) + if (namedAttr.getName() == function_like_impl::getTypeAttrName() || + namedAttr.getName() == SymbolTable::getSymbolAttrName()) continue; - newFuncOp->setAttr(namedAttr.first, namedAttr.second); + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -544,10 +544,10 @@ // Propagate custom user defined optional attributes, that can be used at // later stage, such as extension data for GPU kernel dispatch for (const auto &namedAttr : parallelOp->getAttrs()) { - if (namedAttr.first == gpu::getMappingAttrName() || - namedAttr.first == ParallelOp::getOperandSegmentSizeAttr()) + if (namedAttr.getName() == gpu::getMappingAttrName() || + namedAttr.getName() == ParallelOp::getOperandSegmentSizeAttr()) continue; - 
launchOp->setAttr(namedAttr.first, namedAttr.second); + launchOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } Block *body = parallelOp.getBody(); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -53,11 +53,11 @@ bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { - if (attr.first == SymbolTable::getSymbolAttrName() || - attr.first == function_like_impl::getTypeAttrName() || - attr.first == "std.varargs" || + if (attr.getName() == SymbolTable::getSymbolAttrName() || + attr.getName() == function_like_impl::getTypeAttrName() || + attr.getName() == "std.varargs" || (filterArgAttrs && - attr.first == function_like_impl::getArgDictAttrName())) + attr.getName() == function_like_impl::getArgDictAttrName())) continue; result.push_back(attr); } @@ -255,7 +255,7 @@ rewriter.getArrayAttr(newArgAttrs))); } for (auto pair : llvm::enumerate(attributes)) { - if (pair.value().first == "llvm.linkage") { + if (pair.value().getName() == "llvm.linkage") { attributes.erase(attributes.begin() + pair.index()); break; } @@ -448,9 +448,9 @@ auto newOp = rewriter.create(op.getLoc(), type, symbolRef.getValue()); for (const NamedAttribute &attr : op->getAttrs()) { - if (attr.first.strref() == "value") + if (attr.getName().strref() == "value") continue; - newOp->setAttr(attr.first, attr.second); + newOp->setAttr(attr.getName(), attr.getValue()); } rewriter.replaceOp(op, newOp->getResults()); return success(); diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp @@ -90,10 +90,10 @@ SmallVector opsToSimplify; func.walk([&](Operation *op) { 
for (auto attr : op->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast()) - simplifyAndUpdateAttribute(op, attr.first, mapAttr); - else if (auto setAttr = attr.second.dyn_cast()) - simplifyAndUpdateAttribute(op, attr.first, setAttr); + if (auto mapAttr = attr.getValue().dyn_cast()) + simplifyAndUpdateAttribute(op, attr.getName(), mapAttr); + else if (auto setAttr = attr.getValue().dyn_cast()) + simplifyAndUpdateAttribute(op, attr.getName(), setAttr); } if (isa(op)) diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp --- a/mlir/lib/Dialect/DLTI/DLTI.cpp +++ b/mlir/lib/Dialect/DLTI/DLTI.cpp @@ -367,8 +367,8 @@ LogicalResult DLTIDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - if (attr.first == DLTIDialect::kDataLayoutAttrName) { - if (!attr.second.isa()) { + if (attr.getName() == DLTIDialect::kDataLayoutAttrName) { + if (!attr.getValue().isa()) { return op->emitError() << "'" << DLTIDialect::kDataLayoutAttrName << "' is expected to be a #dlti.dl_spec attribute"; } @@ -377,6 +377,6 @@ return success(); } - return op->emitError() << "attribute '" << attr.first.getValue() + return op->emitError() << "attribute '" << attr.getName().getValue() << "' not supported by dialect"; } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -174,8 +174,8 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { - if (!attr.second.isa() || - attr.first != getContainerModuleAttrName()) + if (!attr.getValue().isa() || + attr.getName() != getContainerModuleAttrName()) return success(); auto module = dyn_cast(op); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -51,9 +51,9 @@ static auto processFMFAttr(ArrayRef attrs) 
{ SmallVector filteredAttrs( llvm::make_filter_range(attrs, [&](NamedAttribute attr) { - if (attr.first == "fastmathFlags") { - auto defAttr = FMFAttr::get(attr.second.getContext(), {}); - return defAttr != attr.second; + if (attr.getName() == "fastmathFlags") { + auto defAttr = FMFAttr::get(attr.getValue().getContext(), {}); + return defAttr != attr.getValue(); } return true; })); @@ -201,7 +201,8 @@ Optional alignmentAttr = result.attributes.getNamed("alignment"); if (alignmentAttr.hasValue()) { - auto alignmentInt = alignmentAttr.getValue().second.dyn_cast(); + auto alignmentInt = + alignmentAttr.getValue().getValue().dyn_cast(); if (!alignmentInt) return parser.emitError(parser.getNameLoc(), "expected integer alignment"); @@ -2317,15 +2318,15 @@ NamedAttribute attr) { // If the `llvm.loop` attribute is present, enforce the following structure, // which the module translation can assume. - if (attr.first.strref() == LLVMDialect::getLoopAttrName()) { - auto loopAttr = attr.second.dyn_cast(); + if (attr.getName() == LLVMDialect::getLoopAttrName()) { + auto loopAttr = attr.getValue().dyn_cast(); if (!loopAttr) return op->emitOpError() << "expected '" << LLVMDialect::getLoopAttrName() << "' to be a dictionary attribute"; Optional parallelAccessGroup = loopAttr.getNamed(LLVMDialect::getParallelAccessAttrName()); if (parallelAccessGroup.hasValue()) { - auto accessGroups = parallelAccessGroup->second.dyn_cast(); + auto accessGroups = parallelAccessGroup->getValue().dyn_cast(); if (!accessGroups) return op->emitOpError() << "expected '" << LLVMDialect::getParallelAccessAttrName() @@ -2353,7 +2354,8 @@ Optional loopOptions = loopAttr.getNamed(LLVMDialect::getLoopOptionsAttrName()); - if (loopOptions.hasValue() && !loopOptions->second.isa()) + if (loopOptions.hasValue() && + !loopOptions->getValue().isa()) return op->emitOpError() << "expected '" << LLVMDialect::getLoopOptionsAttrName() << "' to be a `loopopts` attribute"; @@ -2363,9 +2365,9 @@ // syntax. 
Try parsing it and report errors in case of failure. Users of this // attribute may assume it is well-formed and can pass it to the (asserting) // llvm::DataLayout constructor. - if (attr.first.strref() != LLVM::LLVMDialect::getDataLayoutAttrName()) + if (attr.getName() != LLVM::LLVMDialect::getDataLayoutAttrName()) return success(); - if (auto stringAttr = attr.second.dyn_cast()) + if (auto stringAttr = attr.getValue().dyn_cast()) return verifyDataLayoutString( stringAttr.getValue(), [op](const Twine &message) { op->emitOpError() << message.str(); }); @@ -2381,13 +2383,13 @@ unsigned argIdx, NamedAttribute argAttr) { // Check that llvm.noalias is a unit attribute. - if (argAttr.first == LLVMDialect::getNoAliasAttrName() && - !argAttr.second.isa()) + if (argAttr.getName() == LLVMDialect::getNoAliasAttrName() && + !argAttr.getValue().isa()) return op->emitError() << "expected llvm.noalias argument attribute to be a unit attribute"; // Check that llvm.align is an integer attribute. - if (argAttr.first == LLVMDialect::getAlignAttrName() && - !argAttr.second.isa()) + if (argAttr.getName() == LLVMDialect::getAlignAttrName() && + !argAttr.getValue().isa()) return op->emitError() << "llvm.align argument attribute of non integer type"; return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -57,7 +57,7 @@ return failure(); for (auto &attr : result.attributes) { - if (attr.first != "return_value_and_is_valid") + if (attr.getName() != "return_value_and_is_valid") continue; auto structType = resultType.dyn_cast(); if (structType && !structType.getBody().empty()) @@ -249,7 +249,7 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // Kernel function attribute should be attached to functions. 
- if (attr.first == NVVMDialect::getKernelFuncAttrName()) { + if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) { if (!isa(op)) { return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() << "' attribute attached to unexpected op"; diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -96,7 +96,7 @@ LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op, NamedAttribute attr) { // Kernel function attribute should be attached to functions. - if (attr.first == ROCDLDialect::getKernelFuncAttrName()) { + if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) { if (!isa(op)) { return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName() << "' attribute attached to unexpected op"; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -583,7 +583,7 @@ genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); SmallVector genericAttrs; for (auto attr : op->getAttrs()) - if (genericAttrNamesSet.count(attr.first.strref()) > 0) + if (genericAttrNamesSet.count(attr.getName().strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); @@ -598,7 +598,7 @@ bool hasExtraAttrs = false; for (NamedAttribute n : op->getAttrs()) { - if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) + if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref()))) break; } if (hasExtraAttrs) { @@ -753,8 +753,8 @@ // Copy over unknown attributes. They might be load bearing for some flow. 
ArrayRef odsAttrs = genericOp.getAttributeNames(); for (NamedAttribute kv : genericOp->getAttrs()) { - if (!llvm::is_contained(odsAttrs, kv.first.getValue())) { - newOp->setAttr(kv.first, kv.second); + if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) { + newOp->setAttr(kv.getName(), kv.getValue()); } } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp @@ -152,30 +152,30 @@ NamedAttribute attr) { using comprehensive_bufferize::BufferizableOpInterface; - if (attr.first == BufferizableOpInterface::kInplaceableAttrName) { - if (!attr.second.isa()) { + if (attr.getName() == BufferizableOpInterface::kInplaceableAttrName) { + if (!attr.getValue().isa()) { return op->emitError() << "'" << BufferizableOpInterface::kInplaceableAttrName << "' is expected to be a boolean attribute"; } if (!op->hasTrait()) - return op->emitError() << "expected " << attr.first + return op->emitError() << "expected " << attr.getName() << " to be used on function-like operations"; return success(); } - if (attr.first == BufferizableOpInterface::kBufferLayoutAttrName) { - if (!attr.second.isa()) { + if (attr.getName() == BufferizableOpInterface::kBufferLayoutAttrName) { + if (!attr.getValue().isa()) { return op->emitError() << "'" << BufferizableOpInterface::kBufferLayoutAttrName << "' is expected to be a affine map attribute"; } if (!op->hasTrait()) - return op->emitError() << "expected " << attr.first + return op->emitError() << "expected " << attr.getName() << " to be used on function-like operations"; return success(); } - if (attr.first == LinalgDialect::kMemoizedIndexingMapsAttrName) + if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName) return success(); - return op->emitError() << "attribute '" << attr.first + return op->emitError() << "attribute '" << attr.getName() << "' not supported by the linalg dialect"; } diff 
--git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -1211,8 +1211,8 @@ LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, NamedAttribute attribute) { - StringRef symbol = attribute.first.strref(); - Attribute attr = attribute.second; + StringRef symbol = attribute.getName().strref(); + Attribute attr = attribute.getValue(); // TODO: figure out a way to generate the description from the // StructAttr definition. @@ -1237,8 +1237,8 @@ /// `valueType` is valid. static LogicalResult verifyRegionAttribute(Location loc, Type valueType, NamedAttribute attribute) { - StringRef symbol = attribute.first.strref(); - Attribute attr = attribute.second; + StringRef symbol = attribute.getName().strref(); + Attribute attr = attribute.getValue(); if (symbol != spirv::getInterfaceVarABIAttrName()) return emitError(loc, "found unsupported '") diff --git a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp --- a/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp +++ b/mlir/lib/Dialect/SPIRV/Linking/ModuleCombiner/ModuleCombiner.cpp @@ -76,7 +76,7 @@ static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { auto range = llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { - return attr.first != SymbolTable::getSymbolAttrName(); + return attr.getName() != SymbolTable::getSymbolAttrName(); }); return llvm::hash_combine( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -44,9 +44,8 @@ // Save all named attributes except "type" attribute. 
for (const auto &attr : op->getAttrs()) { - if (attr.first == "type") { + if (attr.getName() == "type") continue; - } globalVarAttrs.push_back(attr); } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -580,9 +580,9 @@ // Copy over all attributes other than the function name and type. for (const auto &namedAttr : funcOp->getAttrs()) { - if (namedAttr.first != function_like_impl::getTypeAttrName() && - namedAttr.first != SymbolTable::getSymbolAttrName()) - newFuncOp->setAttr(namedAttr.first, namedAttr.second); + if (namedAttr.getName() != function_like_impl::getTypeAttrName() && + namedAttr.getName() != SymbolTable::getSymbolAttrName()) + newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue()); } rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -188,12 +188,12 @@ LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, NamedAttribute attribute) { // Verify shape.lib attribute. - if (attribute.first == "shape.lib") { + if (attribute.getName() == "shape.lib") { if (!op->hasTrait()) return op->emitError( "shape.lib attribute may only be on op implementing SymbolTable"); - if (auto symbolRef = attribute.second.dyn_cast()) { + if (auto symbolRef = attribute.getValue().dyn_cast()) { auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef); if (!symbol) return op->emitError("shape function library ") @@ -204,7 +204,7 @@ << symbolRef << " required to be shape function library"; } - if (auto arr = attribute.second.dyn_cast()) { + if (auto arr = attribute.getValue().dyn_cast()) { // Verify all entries are function libraries and mappings in libraries // refer to unique ops. 
DenseSet key; @@ -219,10 +219,10 @@ return op->emitError() << it << " does not refer to FunctionLibraryOp"; for (auto mapping : shapeFnLib.getMapping()) { - if (!key.insert(mapping.first).second) { + if (!key.insert(mapping.getName()).second) { return op->emitError("only one op to shape mapping allowed, found " "multiple for `") - << mapping.first << "`"; + << mapping.getName() << "`"; } } } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -54,8 +54,8 @@ unsigned ptr = 0; unsigned ind = 0; for (const NamedAttribute &attr : dict) { - if (attr.first == "dimLevelType") { - auto arrayAttr = attr.second.dyn_cast(); + if (attr.getName() == "dimLevelType") { + auto arrayAttr = attr.getValue().dyn_cast(); if (!arrayAttr) { parser.emitError(parser.getNameLoc(), "expected an array for dimension level types"); @@ -82,24 +82,24 @@ return {}; } } - } else if (attr.first == "dimOrdering") { - auto affineAttr = attr.second.dyn_cast(); + } else if (attr.getName() == "dimOrdering") { + auto affineAttr = attr.getValue().dyn_cast(); if (!affineAttr) { parser.emitError(parser.getNameLoc(), "expected an affine map for dimension ordering"); return {}; } map = affineAttr.getValue(); - } else if (attr.first == "pointerBitWidth") { - auto intAttr = attr.second.dyn_cast(); + } else if (attr.getName() == "pointerBitWidth") { + auto intAttr = attr.getValue().dyn_cast(); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral pointer bitwidth"); return {}; } ptr = intAttr.getInt(); - } else if (attr.first == "indexBitWidth") { - auto intAttr = attr.second.dyn_cast(); + } else if (attr.getName() == "indexBitWidth") { + auto intAttr = attr.getValue().dyn_cast(); if (!intAttr) { parser.emitError(parser.getNameLoc(), "expected an integral index bitwidth"); @@ -108,7 +108,7 
@@ ind = intAttr.getInt(); } else { parser.emitError(parser.getNameLoc(), "unexpected key: ") - << attr.first.str(); + << attr.getName().strref(); return {}; } } diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -486,7 +486,7 @@ traitAttrsSet.insert(attrNames.begin(), attrNames.end()); SmallVector attrs; for (auto attr : op->getAttrs()) - if (traitAttrsSet.count(attr.first.strref()) > 0) + if (traitAttrsSet.count(attr.getName().strref()) > 0) attrs.push_back(attr); auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -411,7 +411,7 @@ // Consider the attributes of the operation for aliases. for (const NamedAttribute &attr : op->getAttrs()) - printAttribute(attr.second); + printAttribute(attr.getValue()); } /// Print the given block. 
If 'printBlockArgs' is false, the arguments of the @@ -483,14 +483,14 @@ return; if (elidedAttrs.empty()) { for (const NamedAttribute &attr : attrs) - printAttribute(attr.second); + printAttribute(attr.getValue()); return; } llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), elidedAttrs.end()); for (const NamedAttribute &attr : attrs) - if (!elidedAttrsSet.contains(attr.first.strref())) - printAttribute(attr.second); + if (!elidedAttrsSet.contains(attr.getName().strref())) + printAttribute(attr.getValue()); } void printOptionalAttrDictWithKeyword( ArrayRef attrs, @@ -2031,24 +2031,22 @@ llvm::SmallDenseSet elidedAttrsSet(elidedAttrs.begin(), elidedAttrs.end()); auto filteredAttrs = llvm::make_filter_range(attrs, [&](NamedAttribute attr) { - return !elidedAttrsSet.contains(attr.first.strref()); + return !elidedAttrsSet.contains(attr.getName().strref()); }); if (!filteredAttrs.empty()) printFilteredAttributesFn(filteredAttrs); } void AsmPrinter::Impl::printNamedAttribute(NamedAttribute attr) { - assert(attr.first.size() != 0 && "expected valid named attribute"); - // Print the name without quotes if possible. - ::printKeywordOrString(attr.first.strref(), os); + ::printKeywordOrString(attr.getName().strref(), os); // Pretty printing elides the attribute value for unit attributes. 
- if (attr.second.isa()) + if (attr.getValue().isa()) return; os << " = "; - printAttribute(attr.second); + printAttribute(attr.getValue()); } void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -23,9 +23,27 @@ // NamedAttribute //===----------------------------------------------------------------------===// -bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { - return lhs.first.compare(rhs.first) < 0; +NamedAttribute::NamedAttribute(StringAttr name, Attribute value) + : name(name), value(value) { + assert(name && value && "expected valid attribute name and value"); + assert(name.size() != 0 && "expected valid attribute name"); } -bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { - return lhs.first.getValue().compare(rhs) < 0; + +StringAttr NamedAttribute::getName() const { return name.cast(); } + +Dialect *NamedAttribute::getNameDialect() const { + return getName().getReferencedDialect(); +} + +void NamedAttribute::setName(StringAttr newName) { + assert(newName && newName.size() != 0 && "expected valid attribute name"); + name = newName; +} + +bool NamedAttribute::operator<(const NamedAttribute &rhs) const { + return getName().compare(rhs.getName()) < 0; +} + +bool NamedAttribute::operator<(StringRef rhs) const { + return getName().getValue().compare(rhs) < 0; } diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -119,11 +119,12 @@ return none; if (value.size() == 2) - return value[0].first == value[1].first ? value[0] : none; + return value[0].getName() == value[1].getName() ?
value[0] : none; - auto it = std::adjacent_find( - value.begin(), value.end(), - [](NamedAttribute l, NamedAttribute r) { return l.first == r.first; }); + auto it = std::adjacent_find(value.begin(), value.end(), + [](NamedAttribute l, NamedAttribute r) { + return l.getName() == r.getName(); + }); return it != value.end() ? *it : none; } @@ -154,9 +155,6 @@ ArrayRef value) { if (value.empty()) return DictionaryAttr::getEmpty(context); - assert(llvm::all_of(value, - [](const NamedAttribute &attr) { return attr.second; }) && - "value cannot have null entries"); // We need to sort the element list to canonicalize it. SmallVector storage; @@ -173,10 +171,8 @@ if (value.empty()) return DictionaryAttr::getEmpty(context); // Ensure that the attribute elements are unique and sorted. - assert(llvm::is_sorted(value, - [](NamedAttribute l, NamedAttribute r) { - return l.first.strref() < r.first.strref(); - }) && + assert(llvm::is_sorted( + value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"); assert(!findDuplicateElement(value) && "DictionaryAttr element names must be unique"); @@ -186,11 +182,11 @@ /// Return the specified attribute if present, null otherwise. Attribute DictionaryAttr::get(StringRef name) const { auto it = impl::findAttrSorted(begin(), end(), name); - return it.second ? it.first->second : Attribute(); + return it.second ? it.first->getValue() : Attribute(); } Attribute DictionaryAttr::get(StringAttr name) const { auto it = impl::findAttrSorted(begin(), end(), name); - return it.second ? it.first->second : Attribute(); + return it.second ? it.first->getValue() : Attribute(); } /// Return the specified named attribute if present, None otherwise. 
@@ -226,16 +222,16 @@ void DictionaryAttr::walkImmediateSubElements( function_ref walkAttrsFn, function_ref walkTypesFn) const { - for (Attribute attr : llvm::make_second_range(getValue())) - walkAttrsFn(attr); + for (const NamedAttribute &attr : getValue()) + walkAttrsFn(attr.getValue()); } SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute( ArrayRef> replacements) const { std::vector vec = getValue().vec(); - for (auto &it : replacements) { - vec[it.first].second = it.second; - } + for (auto &it : replacements) + vec[it.first].setValue(it.second); + // The above only modifies the mapped value, but not the key, and therefore // not the order of the elements. It remains sorted return getWithSorted(getContext(), vec); diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -153,12 +153,17 @@ /// from this function to dest. void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { // Add the attributes of this function to dest. - llvm::MapVector newAttrs; + llvm::MapVector newAttrMap; for (const auto &attr : dest->getAttrs()) - newAttrs.insert(attr); + newAttrMap.insert({attr.getName(), attr.getValue()}); for (const auto &attr : (*this)->getAttrs()) - newAttrs.insert(attr); - dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector())); + newAttrMap.insert({attr.getName(), attr.getValue()}); + + auto newAttrs = llvm::to_vector(llvm::map_range( + newAttrMap, [](std::pair attrPair) { + return NamedAttribute(attrPair.first, attrPair.second); + })); + dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs)); // Clone the body. getBody().cloneInto(&dest.getBody(), mapper); @@ -235,10 +240,9 @@ // Take the first and only (if present) attribute that implements the // interface. This needs a linear search, but is called only once per data // layout object construction that is used for repeated queries. 
- for (Attribute attr : llvm::make_second_range(getOperation()->getAttrs())) { - if (auto spec = attr.dyn_cast()) + for (NamedAttribute attr : getOperation()->getAttrs()) + if (auto spec = attr.getValue().dyn_cast()) return spec; - } return {}; } @@ -246,30 +250,30 @@ // Check that none of the attributes are non-dialect attributes, except for // the symbol related attributes. for (auto attr : op->getAttrs()) { - if (!attr.first.strref().contains('.') && + if (!attr.getName().strref().contains('.') && !llvm::is_contained( ArrayRef{mlir::SymbolTable::getSymbolAttrName(), mlir::SymbolTable::getVisibilityAttrName()}, - attr.first.strref())) + attr.getName().strref())) return op.emitOpError() << "can only contain attributes with " "dialect-prefixed names, found: '" - << attr.first.getValue() << "'"; + << attr.getName().getValue() << "'"; } // Check that there is at most one data layout spec attribute. StringRef layoutSpecAttrName; DataLayoutSpecInterface layoutSpec; for (const NamedAttribute &na : op->getAttrs()) { - if (auto spec = na.second.dyn_cast()) { + if (auto spec = na.getValue().dyn_cast()) { if (layoutSpec) { InFlightDiagnostic diag = op.emitOpError() << "expects at most one data layout attribute"; diag.attachNote() << "'" << layoutSpecAttrName << "' is a data layout attribute"; - diag.attachNote() << "'" << na.first.getValue() + diag.attachNote() << "'" << na.getName().getValue() << "' is a data layout attribute"; } - layoutSpecAttrName = na.first.strref(); + layoutSpecAttrName = na.getName().strref(); layoutSpec = spec; } } diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -72,11 +72,8 @@ } void NamedAttrList::push_back(NamedAttribute newAttribute) { - assert(newAttribute.second && "unexpected null attribute"); - if (isSorted()) { - dictionarySorted.setInt(attrs.empty() || - attrs.back().first.compare(newAttribute.first) < 0); - } + if (isSorted()) 
+ dictionarySorted.setInt(attrs.empty() || attrs.back() < newAttribute); dictionarySorted.setPointer(nullptr); attrs.push_back(newAttribute); } @@ -84,11 +81,11 @@ /// Return the specified attribute if present, null otherwise. Attribute NamedAttrList::get(StringRef name) const { auto it = findAttr(*this, name); - return it.second ? it.first->second : Attribute(); + return it.second ? it.first->getValue() : Attribute(); } Attribute NamedAttrList::get(StringAttr name) const { auto it = findAttr(*this, name); - return it.second ? it.first->second : Attribute(); + return it.second ? it.first->getValue() : Attribute(); } /// Return the specified named attribute if present, None otherwise. @@ -112,12 +109,14 @@ if (it.second) { // Update the existing attribute by swapping out the old value for the new // value. Return the old value. - if (it.first->second != value) { - std::swap(it.first->second, value); + Attribute oldValue = it.first->getValue(); + if (it.first->getValue() != value) { + it.first->setValue(value); + // If the attributes have changed, the dictionary is invalidated. dictionarySorted.setPointer(nullptr); } - return value; + return oldValue; } // Perform a string lookup to insert the new attribute into its sorted // position. @@ -137,7 +136,7 @@ Attribute NamedAttrList::eraseImpl(SmallVectorImpl::iterator it) { // Erasing does not affect the sorted property. - Attribute attr = it->second; + Attribute attr = it->getValue(); attrs.erase(it); dictionarySorted.setPointer(nullptr); return attr; @@ -485,11 +484,12 @@ // Update any of the provided segment attributes. 
for (OperandSegment &segment : operandSegments) { - auto attr = segment.second.second.cast(); + auto attr = segment.second.getValue().cast(); SmallVector segments(attr.getValues()); segments[segment.first] += diff; - segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments); - owner->setAttr(segment.second.first, segment.second.second); + segment.second.setValue( + DenseIntElementsAttr::get(attr.getType(), segments)); + owner->setAttr(segment.second.getName(), segment.second.getValue()); } } @@ -500,21 +500,21 @@ const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) : MutableOperandRangeRange( OwnerT(operands, operandSegmentAttr), 0, - operandSegmentAttr.second.cast().size()) {} + operandSegmentAttr.getValue().cast().size()) {} MutableOperandRange MutableOperandRangeRange::join() const { return getBase().first; } MutableOperandRangeRange::operator OperandRangeRange() const { - return OperandRangeRange(getBase().first, - getBase().second.second.cast()); + return OperandRangeRange( + getBase().first, getBase().second.getValue().cast()); } MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, ptrdiff_t index) { auto sizeData = - object.second.second.cast().getValues(); + object.second.getValue().cast().getValues(); uint32_t startIndex = std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); return object.first.slice( diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp --- a/mlir/lib/IR/Verifier.cpp +++ b/mlir/lib/IR/Verifier.cpp @@ -170,7 +170,7 @@ /// Verify that all of the attributes are okay. for (auto attr : op.getAttrs()) { // Check for any optional dialect specific attributes. 
- if (auto *dialect = attr.first.getReferencedDialect()) + if (auto *dialect = attr.getNameDialect()) if (failed(dialect->verifyOperationAttribute(&op, attr))) return failure(); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1123,7 +1123,7 @@ Optional duplicate = opState.attributes.findDuplicate(); if (duplicate) return emitError(getNameLoc(), "attribute '") - << duplicate->first.getValue() + << duplicate->getName().getValue() << "' occurs more than once in the attribute list"; return success(); } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -812,7 +812,7 @@ // Insert comma in between operands and non-filtered attributes if needed. if (op.getNumOperands() > 0) { for (NamedAttribute attr : op.getAttrs()) { - if (!llvm::is_contained(exclude, attr.first.strref())) { + if (!llvm::is_contained(exclude, attr.getName().strref())) { os << ", "; break; } @@ -820,10 +820,10 @@ } // Emit attributes. 
auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { - if (llvm::is_contained(exclude, attr.first.strref())) + if (llvm::is_contained(exclude, attr.getName().strref())) return success(); - os << "/* " << attr.first.getValue() << " */"; - if (failed(emitAttribute(op.getLoc(), attr.second))) + os << "/* " << attr.getName().getValue() << " */"; + if (failed(emitAttribute(op.getLoc(), attr.getValue()))) return failure(); return success(); }; diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -224,9 +224,9 @@ SmallVector parallelAccess; parallelAccess.push_back( llvm::MDString::get(ctx, "llvm.loop.parallel_accesses")); - for (SymbolRefAttr accessGroupRef : - parallelAccessGroup->second.cast() - .getAsRange()) + for (SymbolRefAttr accessGroupRef : parallelAccessGroup->getValue() + .cast() + .getAsRange()) parallelAccess.push_back( moduleTranslation.getAccessGroup(opInst, accessGroupRef)); loopOptions.push_back(llvm::MDNode::get(ctx, parallelAccess)); diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -57,7 +57,7 @@ LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { - if (attribute.first == NVVM::NVVMDialect::getKernelFuncAttrName()) { + if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { auto func = dyn_cast(op); if (!func) return failure(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp --- 
a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -64,7 +64,7 @@ LogicalResult amendOperation(Operation *op, NamedAttribute attribute, LLVM::ModuleTranslation &moduleTranslation) const final { - if (attribute.first == ROCDL::ROCDLDialect::getKernelFuncAttrName()) { + if (attribute.getName() == ROCDL::ROCDLDialect::getKernelFuncAttrName()) { auto func = dyn_cast(op); if (!func) return failure(); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -521,7 +521,7 @@ defaultValue); if (decorations.count(resultID)) { for (auto attr : decorations[resultID].getAttrs()) - op->setAttr(attr.first, attr.second); + op->setAttr(attr.getName(), attr.getValue()); } specConstMap[resultID] = op; return op; @@ -591,9 +591,8 @@ // Decorations. 
if (decorations.count(variableID)) { - for (auto attr : decorations[variableID].getAttrs()) { - varOp->setAttr(attr.first, attr.second); - } + for (auto attr : decorations[variableID].getAttrs()) + varOp->setAttr(attr.getName(), attr.getValue()); } globalVariableMap[variableID] = varOp; return success(); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -295,8 +295,9 @@ (void)encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands); for (auto attr : op->getAttrs()) { - if (llvm::any_of(elidedAttrs, - [&](StringRef elided) { return attr.first == elided; })) { + if (llvm::any_of(elidedAttrs, [&](StringRef elided) { + return attr.getName() == elided; + })) { continue; } if (failed(processDecoration(op.getLoc(), resultID, attr))) { @@ -364,8 +365,9 @@ // Encode decorations. for (auto attr : varOp->getAttrs()) { - if (llvm::any_of(elidedAttrs, - [&](StringRef elided) { return attr.first == elided; })) { + if (llvm::any_of(elidedAttrs, [&](StringRef elided) { + return attr.getName() == elided; + })) { continue; } if (failed(processDecoration(varOp.getLoc(), resultID, attr))) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -205,7 +205,7 @@ LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, NamedAttribute attr) { - auto attrName = attr.first.strref(); + auto attrName = attr.getName().strref(); auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); auto decoration = spirv::symbolizeDecoration(decorationName); if (!decoration) { @@ -219,13 +219,13 @@ case spirv::Decoration::Binding: case spirv::Decoration::DescriptorSet: case 
spirv::Decoration::Location: - if (auto intAttr = attr.second.dyn_cast()) { + if (auto intAttr = attr.getValue().dyn_cast()) { args.push_back(intAttr.getValue().getZExtValue()); break; } return emitError(loc, "expected integer attribute for ") << attrName; case spirv::Decoration::BuiltIn: - if (auto strAttr = attr.second.dyn_cast()) { + if (auto strAttr = attr.getValue().dyn_cast()) { auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); if (enumVal) { args.push_back(static_cast(enumVal.getValue())); @@ -243,7 +243,7 @@ case spirv::Decoration::Restrict: case spirv::Decoration::RelaxedPrecision: // For unit attributes, the args list has no values so we do nothing - if (auto unitAttr = attr.second.dyn_cast()) + if (auto unitAttr = attr.getValue().dyn_cast()) break; return emitError(loc, "expected unit attribute for ") << attrName; default: diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -90,7 +90,7 @@ // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = affMapAccInterface.getAffineMapAttrForMemRef(oldMemRef); - AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); + AffineMap oldMap = oldMapAttrPair.getValue().cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, @@ -194,8 +194,8 @@ // Add attribute for 'newMap', other Attributes do not change. 
auto newMapAttr = AffineMapAttr::get(newMap); for (auto namedAttr : op->getAttrs()) { - if (namedAttr.first == oldMapAttrPair.first) - state.attributes.push_back({namedAttr.first, newMapAttr}); + if (namedAttr.getName() == oldMapAttrPair.getName()) + state.attributes.push_back({namedAttr.getName(), newMapAttr}); else state.attributes.push_back(namedAttr); } diff --git a/mlir/lib/Transforms/ViewOpGraph.cpp b/mlir/lib/Transforms/ViewOpGraph.cpp --- a/mlir/lib/Transforms/ViewOpGraph.cpp +++ b/mlir/lib/Transforms/ViewOpGraph.cpp @@ -221,8 +221,8 @@ if (printAttrs) { os << "\n"; for (const NamedAttribute &attr : op->getAttrs()) { - os << '\n' << attr.first.getValue() << ": "; - emitMlirAttr(os, attr.second); + os << '\n' << attr.getName().getValue() << ": "; + emitMlirAttr(os, attr.getValue()); } } }); diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -288,7 +288,7 @@ LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { - if (namedAttr.first == "test.invalid_attr") + if (namedAttr.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -297,7 +297,7 @@ unsigned regionIndex, unsigned argIndex, NamedAttribute namedAttr) { - if (namedAttr.first == "test.invalid_attr") + if (namedAttr.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -306,7 +306,7 @@ TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex, unsigned resultIndex, NamedAttribute namedAttr) { - if (namedAttr.first == "test.invalid_attr") + if (namedAttr.getName() == "test.invalid_attr") return op->emitError() << "invalid to use 'test.invalid_attr'"; return success(); } @@ -942,7 +942,7 @@ // If the attribute dictionary contains no 'names' attribute, infer it from // 
the SSA name (if specified). bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { - return attr.first == "names"; + return attr.getName() == "names"; }); // If there was no name specified, check to see if there was a useful name diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -243,7 +243,7 @@ if (!dAttr) return; for (auto d : dAttr) - dOp.emitRemark() << d.first.getValue() << " = " << d.second; + dOp.emitRemark() << d.getName().getValue() << " = " << d.getValue(); }); } diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -23,7 +23,7 @@ void runOnOperation() override { getOperation().walk([&](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - auto elementsAttr = attr.second.dyn_cast(); + auto elementsAttr = attr.getValue().dyn_cast(); if (!elementsAttr) continue; testElementsAttrIteration(op, elementsAttr, "uint64_t"); diff --git a/mlir/test/lib/IR/TestPrintNesting.cpp b/mlir/test/lib/IR/TestPrintNesting.cpp --- a/mlir/test/lib/IR/TestPrintNesting.cpp +++ b/mlir/test/lib/IR/TestPrintNesting.cpp @@ -37,8 +37,8 @@ if (!op->getAttrs().empty()) { printIndent() << op->getAttrs().size() << " attributes:\n"; for (NamedAttribute attr : op->getAttrs()) - printIndent() << " - '" << attr.first.getValue() << "' : '" - << attr.second << "'\n"; + printIndent() << " - '" << attr.getName().getValue() << "' : '" + << attr.getValue() << "'\n"; } // Recurse into each of the regions attached to the operation. 
diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -51,7 +51,7 @@ // CHECK-LABEL: OpD definitions // CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) -// CHECK: odsState.addTypes({attr.second.cast<::mlir::TypeAttr>().getValue()}); +// CHECK: odsState.addTypes({attr.getValue().cast<::mlir::TypeAttr>().getValue()}); def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, F32Attr:$attr); @@ -60,7 +60,7 @@ // CHECK-LABEL: OpE definitions // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) -// CHECK: odsState.addTypes({attr.second.getType()}); +// CHECK: odsState.addTypes({attr.getValue().getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { let results = (outs Variadic:$x); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -1449,11 +1449,11 @@ << "AttrName(" << builderOpState << ".name);\n" " for (auto attr : attributes) {\n" - " if (attr.first != attrName) continue;\n"; + " if (attr.getName() != attrName) continue;\n"; if (namedAttr.attr.isTypeAttr()) { - resultType = "attr.second.cast<::mlir::TypeAttr>().getValue()"; + resultType = "attr.getValue().cast<::mlir::TypeAttr>().getValue()"; } else { - resultType = "attr.second.getType()"; + resultType = "attr.getValue().getType()"; } // Operands diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -673,7 +673,8 @@ // All 
non-argument attributes translated into OpDecorate instruction os << tabs << formatv("for (auto attr : {0}->getAttrs()) {{\n", opVar); os << tabs - << formatv(" if (llvm::is_contained({0}, attr.first)) {{", elidedAttrs); + << formatv(" if (llvm::is_contained({0}, attr.getName())) {{", + elidedAttrs); os << tabs << " continue;\n"; os << tabs << " }\n"; os << tabs diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -237,11 +237,11 @@ { auto it = attrs.begin(); - EXPECT_EQ(it->first, b.getStringAttr("foo")); - EXPECT_EQ(it->second, b.getStringAttr("bar")); + EXPECT_EQ(it->getName(), b.getStringAttr("foo")); + EXPECT_EQ(it->getValue(), b.getStringAttr("bar")); ++it; - EXPECT_EQ(it->first, b.getStringAttr("baz")); - EXPECT_EQ(it->second, b.getStringAttr("boo")); + EXPECT_EQ(it->getName(), b.getStringAttr("baz")); + EXPECT_EQ(it->getValue(), b.getStringAttr("boo")); } attrs.append("foo", b.getStringAttr("zoo")); @@ -261,11 +261,11 @@ { auto it = attrs.begin(); - EXPECT_EQ(it->first, b.getStringAttr("foo")); - EXPECT_EQ(it->second, b.getStringAttr("f")); + EXPECT_EQ(it->getName(), b.getStringAttr("foo")); + EXPECT_EQ(it->getValue(), b.getStringAttr("f")); ++it; - EXPECT_EQ(it->first, b.getStringAttr("zoo")); - EXPECT_EQ(it->second, b.getStringAttr("z")); + EXPECT_EQ(it->getName(), b.getStringAttr("zoo")); + EXPECT_EQ(it->getValue(), b.getStringAttr("z")); } attrs.assign({}); diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -62,7 +62,8 @@ EXPECT_EQ(op->getAttrs().size(), attrs.size()); for (unsigned idx : llvm::seq(0U, attrs.size())) - EXPECT_EQ(op->getAttr(attrs[idx].first.strref()), attrs[idx].second); + EXPECT_EQ(op->getAttr(attrs[idx].getName().strref()), + attrs[idx].getValue()); 
concreteOp.erase(); } diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -62,7 +62,7 @@ // Add an extra NamedAttribute. auto wrongId = mlir::StringAttr::get(&context, "wrong"); - auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); + auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].getValue()); newValues.push_back(wrongAttr); // Make a new DictionaryAttr and validate. @@ -84,7 +84,7 @@ // Add a copy of the first attribute with the wrong name. auto wrongId = mlir::StringAttr::get(&context, "wrong"); - auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); + auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].getValue()); newValues.push_back(wrongAttr); auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); @@ -108,7 +108,7 @@ auto elementsType = mlir::RankedTensorType::get({3}, i64Type); auto elementsAttr = mlir::DenseIntElementsAttr::get(elementsType, ArrayRef{1, 2, 3}); - mlir::StringAttr id = expectedValues.back().first; + mlir::StringAttr id = expectedValues.back().getName(); auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); newValues.push_back(wrongAttr);