diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -654,7 +654,7 @@ mlir::Attribute ivalue; // Integer or Unit mlir::Block *dest; llvm::SmallVector destArg; - llvm::SmallVector temp; + mlir::NamedAttrList temp; if (parser.parseAttribute(ivalue, "i", temp) || parser.parseComma() || parser.parseSuccessorAndUseList(dest, destArg)) @@ -2215,7 +2215,7 @@ let parser = [{ auto &builder = parser.getBuilder(); mlir::Attribute val; - llvm::SmallVector attrs; + mlir::NamedAttrList attrs; if (parser.parseAttribute(val, "fake", attrs)) return mlir::failure(); if (auto v = val.dyn_cast()) @@ -2858,8 +2858,8 @@ return failure(); // Convert the parsed name attr into a string attr. - result.attributes.back().second = - parser.getBuilder().getStringAttr(nameAttr.getRootReference()); + result.attributes.set(mlir::SymbolTable::getSymbolAttrName(), + parser.getBuilder().getStringAttr(nameAttr.getRootReference())); // Parse the optional table body.
mlir::Region *body = result.addRegion(); diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp --- a/flang/lib/Optimizer/Dialect/FIROps.cpp +++ b/flang/lib/Optimizer/Dialect/FIROps.cpp @@ -181,7 +181,7 @@ if (parser.parseOperandList(operands)) return mlir::failure(); - llvm::SmallVector attrs; + mlir::NamedAttrList attrs; mlir::SymbolRefAttr funcAttr; bool isDirect = operands.empty(); if (isDirect) @@ -259,7 +259,7 @@ static mlir::ParseResult parseCmpOp(mlir::OpAsmParser &parser, mlir::OperationState &result) { llvm::SmallVector ops; - llvm::SmallVector attrs; + mlir::NamedAttrList attrs; mlir::Attribute predicateNameAttr; mlir::Type type; if (parser.parseAttribute(predicateNameAttr, OPTY::getPredicateAttrName(), @@ -279,7 +279,8 @@ auto predicate = fir::CmpfOp::getPredicateByName(predicateName); auto builder = parser.getBuilder(); mlir::Type i1Type = builder.getI1Type(); - attrs[0].second = builder.getI64IntegerAttr(static_cast(predicate)); + attrs.set(OPTY::getPredicateAttrName(), + builder.getI64IntegerAttr(static_cast(predicate))); result.attributes = attrs; result.addTypes({i1Type}); return success(); @@ -1102,7 +1103,7 @@ mlir::Attribute attr; mlir::Block *dest; llvm::SmallVector destArg; - llvm::SmallVector temp; + mlir::NamedAttrList temp; if (parser.parseAttribute(attr, "a", temp) || isValidCaseAttr(attr) || parser.parseComma()) return mlir::failure(); @@ -1323,7 +1324,7 @@ mlir::Attribute attr; mlir::Block *dest; llvm::SmallVector destArg; - llvm::SmallVector temp; + mlir::NamedAttrList temp; if (parser.parseAttribute(attr, "a", temp) || parser.parseComma() || parser.parseSuccessorAndUseList(dest, destArg)) return mlir::failure(); diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -268,6 +268,9 @@ /// be non-null.
using NamedAttribute = std::pair; +bool operator<(const NamedAttribute &lhs, const NamedAttribute &rhs); +bool operator<(const NamedAttribute &lhs, StringRef rhs); + /// Dictionary attribute is an attribute that represents a sorted collection of /// named attribute values. The elements are sorted by name, and each name must /// be unique within the collection. @@ -308,14 +311,25 @@ size_t size() const; /// Sorts the NamedAttributes in the array ordered by name as expected by - /// getWithSorted. + /// getWithSorted and returns whether the values needed to be sorted. /// Requires: uniquely named attributes. - static void sort(SmallVectorImpl &array); + static bool sort(ArrayRef values, + SmallVectorImpl &storage); + + /// Sorts the NamedAttributes in the array ordered by name as expected by + /// getWithSorted in place on an array and returns whether the values needed + /// to be sorted. + /// Requires: uniquely named attributes. + static bool sortInPlace(SmallVectorImpl &array); /// Methods for supporting type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { return kind == StandardAttributes::Dictionary; } + +private: + /// Return the empty dictionary attribute cached in `context`. + static DictionaryAttr getEmpty(MLIRContext *context); }; //===----------------------------------------------------------------------===// @@ -1504,9 +1518,8 @@ return attrs == other.attrs; } - /// Return the underlying dictionary attribute. This may be null, if this list - /// has no attributes. - DictionaryAttr getDictionary() const { return attrs; } + /// Return the underlying dictionary attribute. + DictionaryAttr getDictionary(MLIRContext *context) const; /// Return all of the attributes on this operation. ArrayRef getAttrs() const; @@ -1532,10 +1545,20 @@ /// value indicates whether the attribute was present or not.
RemoveResult remove(Identifier name); + bool empty() const { return !attrs || attrs.empty(); } + private: + friend ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg); + DictionaryAttr attrs; }; +inline ::llvm::hash_code hash_value(const MutableDictionaryAttr &arg) { + if (!arg.attrs || arg.attrs.empty()) + return ::llvm::hash_value((void *)nullptr); + return hash_value(arg.attrs.cast()); +} + } // end namespace mlir. namespace llvm { diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -38,8 +38,8 @@ /// Internally, argument and result attributes are stored as dict attributes /// with special names given by getResultAttrName, getArgumentAttrName. void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef> argAttrs, - ArrayRef> resultAttrs); + ArrayRef argAttrs, + ArrayRef resultAttrs); /// Callback type for `parseFunctionLikeOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -53,13 +53,13 @@ /// indicates whether functions with variadic arguments are supported. The /// trailing arguments are populated by this function with names, types and /// attributes of the arguments and those of the results. -ParseResult parseFunctionSignature( - OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs); +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &argNames, + SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); /// Parser implementation for function-like operations.
Uses /// `funcTypeBuilder` to construct the custom function type given lists of diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -517,10 +517,14 @@ MutableDictionaryAttr attributes) { assert(index < getNumArguments() && "invalid argument number"); SmallString<8> nameOut; - if (auto newAttr = attributes.getDictionary()) + if (attributes.getAttrs().empty()) { + this->getOperation()->removeAttr(getArgAttrName(index, nameOut)); + } else { + auto newAttr = attributes.getDictionary( + attributes.getAttrs().front().second.getContext()); return this->getOperation()->setAttr(getArgAttrName(index, nameOut), newAttr); - static_cast(this)->removeAttr(getArgAttrName(index, nameOut)); + } } /// If the an attribute exists with the specified name, change it to the new @@ -533,7 +537,7 @@ attrDict.set(name, value); // If the attribute changed, then set the new arg attribute list.
- if (curAttr != attrDict.getDictionary()) + if (curAttr != attrDict.getDictionary(value.getContext())) setArgAttrs(index, attrDict); } @@ -564,7 +568,7 @@ getResultAttrName(index, nameOut); if (attributes.empty()) - return (void)static_cast(this)->removeAttr(nameOut); + return (void)this->getOperation()->removeAttr(nameOut); Operation *op = this->getOperation(); op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); } @@ -574,11 +578,14 @@ unsigned index, MutableDictionaryAttr attributes) { assert(index < getNumResults() && "invalid result number"); SmallString<8> nameOut; - if (auto newAttr = attributes.getDictionary()) + if (attributes.getAttrs().empty()) { + this->getOperation()->removeAttr(getResultAttrName(index, nameOut)); + } else { + auto newAttr = attributes.getDictionary( + attributes.getAttrs().front().second.getContext()); return this->getOperation()->setAttr(getResultAttrName(index, nameOut), newAttr); - static_cast(this)->removeAttr( - getResultAttrName(index, nameOut)); + } } /// If the an attribute exists with the specified name, change it to the new @@ -591,7 +598,7 @@ attrDict.set(name, value); // If the attribute changed, then set the new arg attribute list. - if (curAttr != attrDict.getDictionary()) + if (curAttr != attrDict.getDictionary(value.getContext())) setResultAttrs(index, attrDict); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -366,28 +366,28 @@ /// Parse an arbitrary attribute and return it in result. This also adds the /// attribute to the specified attribute list with the specified name. ParseResult parseAttribute(Attribute &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { return parseAttribute(result, Type(), attrName, attrs); } /// Parse an attribute of a specific kind and type.
template ParseResult parseAttribute(AttrType &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { return parseAttribute(result, Type(), attrName, attrs); } /// Parse an arbitrary attribute of a given type and return it in result. This /// also adds the attribute to the specified attribute list with the specified /// name. - virtual ParseResult - parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult parseAttribute(Attribute &result, Type type, + StringRef attrName, + NamedAttrList &attrs) = 0; /// Parse an attribute of a specific kind and type. template ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { llvm::SMLoc loc = getCurrentLocation(); // Parse any kind of attribute. @@ -404,13 +404,12 @@ } /// Parse a named dictionary into 'result' if it is present. - virtual ParseResult - parseOptionalAttrDict(SmallVectorImpl &result) = 0; + virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; /// Parse a named dictionary into 'result' if the `attributes` keyword is /// present. virtual ParseResult - parseOptionalAttrDictWithKeyword(SmallVectorImpl &result) = 0; + parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; /// Parse an affine map instance into 'map'. virtual ParseResult parseAffineMap(AffineMap &map) = 0; @@ -425,7 +424,7 @@ /// Parse an @-identifier and store it (without the '@' symbol) in a string /// attribute named 'attrName'. ParseResult parseSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { if (failed(parseOptionalSymbolName(result, attrName, attrs))) return emitError(getCurrentLocation()) << "expected valid '@'-identifier for symbol name"; @@ -434,9 +433,9 @@ /// Parse an optional @-identifier and store it (without the '@' symbol) in a /// string attribute named 'attrName'.
- virtual ParseResult - parseOptionalSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult parseOptionalSymbolName(StringAttr &result, + StringRef attrName, + NamedAttrList &attrs) = 0; //===--------------------------------------------------------------------===// // Operand Parsing @@ -552,8 +551,7 @@ /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &map, - StringRef attrName, - SmallVectorImpl &attrs, + StringRef attrName, NamedAttrList &attrs, Delimiter delimiter = Delimiter::Square) = 0; //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -294,6 +294,11 @@ /// Return all of the attributes on this operation. ArrayRef getAttrs() { return attrs.getAttrs(); } + /// Return all of the attributes on this operation as a DictionaryAttr. + DictionaryAttr getAttrDictionary() { + return attrs.getDictionary(getContext()); + } + /// Return mutable container of all the attributes on this operation. MutableDictionaryAttr &getMutableAttrDict() { return attrs; } @@ -326,6 +331,9 @@ MutableDictionaryAttr::RemoveResult removeAttr(Identifier name) { return attrs.remove(name); } + MutableDictionaryAttr::RemoveResult removeAttr(StringRef name) { + return attrs.remove(Identifier::get(name, getContext())); + } /// A utility iterator that filters out non-dialect attributes.
class dialect_attr_iterator diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -200,6 +200,100 @@ bool (&hasRawTrait)(TypeID traitID); }; +//===----------------------------------------------------------------------===// +// NamedAttrList +//===----------------------------------------------------------------------===// + +/// NamedAttrList is an array of NamedAttributes that tracks whether it is sorted +/// and does some basic work to remain sorted. +class NamedAttrList { +public: + using const_iterator = SmallVectorImpl::const_iterator; + using const_reference = const NamedAttribute &; + using reference = NamedAttribute &; + using size_type = size_t; + + NamedAttrList() : dictionarySorted({}, true) {} + NamedAttrList(ArrayRef attributes); + NamedAttrList(const_iterator in_start, const_iterator in_end); + + bool operator!=(const NamedAttrList &other) const { + return !(*this == other); + } + bool operator==(const NamedAttrList &other) const { + return attrs == other.attrs; + } + + /// Add an attribute with the specified name. + void append(StringRef name, Attribute attr); + + /// Add an attribute with the specified name. + void append(Identifier name, Attribute attr); + + /// Add an array of named attributes. + void append(ArrayRef newAttributes); + + /// Add a range of named attributes. + void append(const_iterator in_start, const_iterator in_end); + + /// Replaces the attributes with new list of attributes. + void assign(const_iterator in_start, const_iterator in_end); + + /// Replaces the attributes with new list of attributes. + void assign(ArrayRef range) { + assign(range.begin(), range.end()); + } + + bool empty() const { return attrs.empty(); } + + void reserve(size_type N) { attrs.reserve(N); } + + /// Add an attribute with the specified name. + void push_back(NamedAttribute newAttribute); + + /// Pop last element from list.
+ void pop_back() { attrs.pop_back(); } + + /// Return a dictionary attribute for the underlying dictionary. This will + /// return an empty dictionary attribute if empty rather than null. + DictionaryAttr getDictionary(MLIRContext *context) const; + + /// Return all of the attributes in this list. + ArrayRef getAttrs() const; + + /// Return the specified attribute if present, null otherwise. + Attribute get(Identifier name) const; + Attribute get(StringRef name) const; + + /// Return the specified named attribute if present, None otherwise. + Optional getNamed(StringRef name) const; + Optional getNamed(Identifier name) const; + + /// If an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void set(Identifier name, Attribute value); + void set(StringRef name, Attribute value); + + const_iterator begin() const { return attrs.begin(); } + const_iterator end() const { return attrs.end(); } + + NamedAttrList &operator=(const SmallVectorImpl &rhs); + operator ArrayRef() const; + operator MutableDictionaryAttr() const; + +private: + /// Return whether the attributes are sorted. + bool isSorted() const { return dictionarySorted.getInt(); } + + // These are marked mutable as they may be modified (e.g., sorted). + mutable SmallVector attrs; + // Pair with cached DictionaryAttr and status of whether attrs is sorted. + // Note: just because sorted does not mean a DictionaryAttr has been created + // but the case where there is a DictionaryAttr but attrs isn't sorted should + // not occur. + mutable llvm::PointerIntPair dictionarySorted; +}; + //===----------------------------------------------------------------------===// // OperationName //===----------------------------------------------------------------------===// @@ -268,7 +362,7 @@ SmallVector operands; /// Types of the results of this operation.
SmallVector types; - SmallVector attributes; + NamedAttrList attributes; /// Successors of this operation and their respective operands. SmallVector successors; /// Regions that the op will hold. @@ -302,12 +396,12 @@ /// Add an attribute with the specified name. void addAttribute(Identifier name, Attribute attr) { - attributes.push_back({name, attr}); + attributes.append(name, attr); } /// Add an array of named attributes. void addAttributes(ArrayRef newAttributes) { - attributes.append(newAttributes.begin(), newAttributes.end()); + attributes.append(newAttributes); } /// Add an array of successors. @@ -328,7 +422,7 @@ void addRegion(std::unique_ptr &&region); /// Get the context held by this operation state. - MLIRContext *getContext() { return location->getContext(); } + MLIRContext *getContext() const { return location->getContext(); } }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -83,11 +83,11 @@ LogicalResult inferReturnTensorTypes( function_ref location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for @@ -107,7 +107,7 @@ public: static LogicalResult inferReturnTypes(MLIRContext *context, Optional location, - ValueRange operands, ArrayRef attributes, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { return ::mlir::detail::inferReturnTensorTypes(
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -40,7 +40,7 @@ /*args=*/(ins "MLIRContext*":$context, "Optional":$location, "ValueRange":$operands, - "ArrayRef":$attributes, + "DictionaryAttr":$attributes, "RegionRange":$regions, "SmallVectorImpl&":$inferredReturnTypes) >, @@ -92,7 +92,7 @@ /*args=*/(ins "MLIRContext*":$context, "Optional":$location, "ValueRange":$operands, - "ArrayRef":$attributes, + "DictionaryAttr":$attributes, "RegionRange":$regions, "SmallVectorImpl&": $inferredReturnShapes) diff --git a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -179,7 +179,8 @@ auto *parentOp = callableRegion->getParentOp(); os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" << callableRegion->getRegionNumber(); - if (auto attrs = parentOp->getMutableAttrDict().getDictionary()) + auto attrs = parentOp->getAttrDictionary(); + if (!attrs.empty()) os << " : " << attrs; }; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2453,7 +2453,7 @@ return failure(); AffineMapAttr stepsMapAttr; - SmallVector stepsAttrs; + NamedAttrList stepsAttrs; SmallVector stepsMapOperands; if (failed(parser.parseOptionalKeyword("step"))) { SmallVector steps(ivs.size(), 1); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -538,8 +538,8 @@ /// function-attributes?
region static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { SmallVector entryArgs; - SmallVector, 1> argAttrs; - SmallVector, 1> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; bool isVariadic; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -92,8 +92,8 @@ predicateValue = static_cast(predicate.getValue()); } - result.attributes[0].second = - parser.getBuilder().getI64IntegerAttr(predicateValue); + result.attributes.set("predicate", + parser.getBuilder().getI64IntegerAttr(predicateValue)); // The result type is either i1 or a vector type if the inputs are // vectors. @@ -1186,7 +1186,7 @@ "expected as many argument attribute lists as arguments"); SmallString<8> argAttrName; for (unsigned i = 0; i < numInputs; ++i) - if (auto argDict = argAttrs[i].getDictionary()) + if (auto argDict = argAttrs[i].getAttrs().empty() ? DictionaryAttr() : argAttrs[i].getDictionary(builder.getContext())) result.addAttribute(getArgAttrName(i, argAttrName), argDict); } @@ -1249,8 +1249,8 @@ StringAttr nameAttr; SmallVector entryArgs; - SmallVector, 1> argAttrs; - SmallVector, 1> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; bool isVariadic; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -353,8 +353,8 @@ } bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { - return lhs.getOperation()->getMutableAttrDict().getDictionary() == - rhs.getOperation()->getMutableAttrDict().getDictionary(); + return lhs.getOperation()->getAttrDictionary() == + rhs.getOperation()->getAttrDictionary(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -105,7 +105,7 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName = spirv::attributeName()) { Attribute attrVal; - SmallVector attr; + NamedAttrList attr; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) { @@ -1019,7 +1019,7 @@ // Parse the optional branch weights. if (succeeded(parser.parseOptionalLSquare())) { IntegerAttr trueWeight, falseWeight; - SmallVector weights; + NamedAttrList weights; auto i32Type = builder.getIntegerType(32); if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || @@ -1443,7 +1443,7 @@ // The name of the interface variable attribute isnt important auto attrName = "var_symbol"; FlatSymbolRefAttr var; - SmallVector attrs; + NamedAttrList attrs; if (parser.parseAttribute(var, Type(), attrName, attrs)) { return failure(); } @@ -1497,7 +1497,7 @@ SmallVector values; Type i32Type = parser.getBuilder().getIntegerType(32); while (!parser.parseOptionalComma()) { - SmallVector attr; + NamedAttrList attr; Attribute value; if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); @@ -1529,8 +1529,8 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { SmallVector entryArgs; - SmallVector, 4> argAttrs; - SmallVector, 4> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -92,10 +92,11 @@ // BroadcastOp //===----------------------------------------------------------------------===// -LogicalResult BroadcastOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef
attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +BroadcastOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -137,7 +138,7 @@ // shape as an ArrayAttr. // TODO: Implement custom parser and maybe make syntax a bit more concise. Attribute extentsRaw; - SmallVector dummy; + NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); auto extentsArray = extentsRaw.dyn_cast(); @@ -159,10 +160,11 @@ OpFoldResult ConstShapeOp::fold(ArrayRef) { return shape(); } -LogicalResult ConstShapeOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConstShapeOp::inferReturnTypes(MLIRContext *context, + Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -171,10 +173,11 @@ // ConstSizeOp //===----------------------------------------------------------------------===// -LogicalResult ConstSizeOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } @@ -183,10 +186,11 @@ // ShapeOfOp //===----------------------------------------------------------------------===// -LogicalResult ShapeOfOp::inferReturnTypes( - MLIRContext *context, Optional
location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -203,10 +207,11 @@ // SplitAtOp //===----------------------------------------------------------------------===// -LogicalResult SplitAtOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +SplitAtOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); inferredReturnTypes.push_back(shapeType); @@ -238,10 +243,11 @@ // ConcatOp //===----------------------------------------------------------------------===// -LogicalResult ConcatOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConcatOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -478,7 +478,7 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc attributeLoc, typeLoc; - SmallVector attrs; + NamedAttrList attrs;
OpAsmParser::OperandType vector; Type type; Attribute attr; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -111,21 +111,30 @@ // Specialize for the common case. switch (value.size()) { case 0: + storage.clear(); // Zero elements are sorted; drop any stale output. + break; case 1: - // Zero or one elements are already sorted. + // One already sorted but may need to be copied. + if (!inPlace) + storage.assign({value[0]}); break; - case 2: + case 2: { assert(value[0].first != value[1].first && "DictionaryAttr element names must be unique"); - if (compareNamedAttributes(&value[0], &value[1]) > 0) { - if (inPlace) + bool isSorted = compareNamedAttributes(&value[0], &value[1]) < 0; + if (inPlace) { + if (!isSorted) std::swap(storage[0], storage[1]); - else - storage.append({value[1], value[0]}); - return true; + } else if (isSorted) { + storage.assign({value[0], value[1]}); + } else { + storage.assign({value[1], value[0]}); } - break; + return !isSorted; + } default: + if (!inPlace) + storage.assign(value.begin(), value.end()); // Check to see they are sorted already. bool isSorted = llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) { @@ -133,8 +142,6 @@ }); if (!isSorted) { // If not, do a general sort. - if (!inPlace) - storage.append(value.begin(), value.end()); llvm::array_pod_sort(storage.begin(), storage.end(), compareNamedAttributes); value = storage; @@ -151,15 +158,19 @@ return false; } -/// Sorts the NamedAttributes in the array ordered by name as expected by -/// getWithSorted. -/// Requires: uniquely named attributes.
-void DictionaryAttr::sort(SmallVectorImpl &array) { - dictionaryAttrSort(array, array); +bool DictionaryAttr::sort(ArrayRef value, + SmallVectorImpl &storage) { + return dictionaryAttrSort(value, storage); +} + +bool DictionaryAttr::sortInPlace(SmallVectorImpl &array) { + return dictionaryAttrSort(array, array); } DictionaryAttr DictionaryAttr::get(ArrayRef value, MLIRContext *context) { + if (value.empty()) + return DictionaryAttr::getEmpty(context); assert(llvm::all_of(value, [](const NamedAttribute &attr) { return attr.second; }) && "value cannot have null entries"); @@ -171,11 +182,12 @@ return Base::get(context, StandardAttributes::Dictionary, value); } - /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef value, MLIRContext *context) { + if (value.empty()) + return DictionaryAttr::getEmpty(context); // Ensure that the attribute elements are unique and sorted. assert(llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) { @@ -207,7 +219,8 @@ /// Return the specified named attribute if present, None otherwise. Optional DictionaryAttr::getNamed(StringRef name) const { ArrayRef values = getValue(); - auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); + const auto *it = + llvm::lower_bound(values, name, compareNamedAttributeWithName); return it != values.end() && it->first == name ? *it : Optional(); } @@ -1193,6 +1206,15 @@ setAttrs(attributes); } +/// Return the underlying dictionary attribute. +DictionaryAttr +MutableDictionaryAttr::getDictionary(MLIRContext *context) const { + // Construct empty DictionaryAttr if needed. + if (!attrs) + return DictionaryAttr::get({}, context); + return attrs; +} + ArrayRef MutableDictionaryAttr::getAttrs() const { return attrs ? attrs.getValue() : llvm::None; } @@ -1232,7 +1254,7 @@ // Look for an existing value for the given name, and set it in-place.
ArrayRef values = getAttrs(); - auto it = llvm::find_if( + const auto *it = llvm::find_if( values, [name](NamedAttribute attr) { return attr.first == name; }); if (it != values.end()) { // Bail out early if the value is the same as what we already have. @@ -1278,3 +1300,10 @@ } return RemoveResult::NotFound; } + +bool mlir::operator<(const NamedAttribute &lhs, const NamedAttribute &rhs) { + return lhs.first.strref() < rhs.first.strref(); +} +bool mlir::operator<(const NamedAttribute &lhs, StringRef rhs) { + return lhs.first.strref() < rhs; +} diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -57,7 +57,7 @@ assert(type.getNumInputs() == argAttrs.size()); SmallString<8> argAttrName; for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (auto argDict = argAttrs[i].getDictionary()) + if (auto argDict = argAttrs[i].getAttrs().empty() ? DictionaryAttr() : argAttrs[i].getDictionary(builder.getContext())) result.addAttribute(getArgAttrName(i, argAttrName), argDict); } diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -17,8 +17,7 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argTypes, SmallVectorImpl &argNames, - SmallVectorImpl> &argAttrs, - bool &isVariadic) { + SmallVectorImpl &argAttrs, bool &isVariadic) { if (parser.parseLParen()) return failure(); @@ -54,7 +53,7 @@ argTypes.push_back(argumentType); // Parse any argument attributes. - SmallVector attrs; + NamedAttrList attrs; if (parser.parseOptionalAttrDict(attrs)) return failure(); argAttrs.push_back(attrs); @@ -90,9 +89,9 @@ /// function-result-list-no-parens ::= function-result (`,` function-result)* /// function-result ::= type attribute-dict?
/// -static ParseResult parseFunctionResultList( - OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs) { +static ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { if (failed(parser.parseOptionalLParen())) { // We already know that there is no `(`, so parse a type. // Because there is no `(`, it cannot be a function type. @@ -127,10 +126,9 @@ ParseResult mlir::impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs) { + SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, isVariadic)) return failure(); @@ -139,10 +137,9 @@ return success(); } -void mlir::impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, - ArrayRef> argAttrs, - ArrayRef> resultAttrs) { +void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs) { // Add the attributes to the function arguments. 
SmallString<8> attrNameBuf; for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) @@ -164,8 +161,8 @@ bool allowVariadic, mlir::impl::FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; - SmallVector, 4> argAttrs; - SmallVector, 4> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; auto &builder = parser.getBuilder(); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -328,6 +328,7 @@ BoolAttr falseAttr, trueAttr; UnitAttr unitAttr; UnknownLoc unknownLocAttr; + DictionaryAttr emptyDictionaryAttr; public: MLIRContextImpl() : identifiers(identifierAllocator) {} @@ -388,6 +389,9 @@ /// Unknown Location Attribute. impl->unknownLocAttr = AttributeUniquer::get( this, StandardAttributes::UnknownLocation); + /// The empty dictionary attribute. + impl->emptyDictionaryAttr = AttributeUniquer::get( + this, StandardAttributes::Dictionary, ArrayRef()); } MLIRContext::~MLIRContext() {} @@ -742,6 +746,11 @@ return context->getImpl().unknownLocAttr; } +/// Return the empty dictionary attribute cached in `context`. +DictionaryAttr DictionaryAttr::getEmpty(MLIRContext *context) { + return context->getImpl().emptyDictionaryAttr; +} + //===----------------------------------------------------------------------===// // AffineMap uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -70,9 +70,9 @@ /// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) { - return Operation::create( - state.location, state.name, state.types, state.operands, - MutableDictionaryAttr(state.attributes), state.successors, state.regions); + return Operation::create(state.location, state.name, state.types, + state.operands, state.attributes, state.successors, + state.regions); } /// Create a new Operation with the specific fields. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -18,6 +18,146 @@ #include "mlir/IR/StandardTypes.h" using namespace mlir; +//===----------------------------------------------------------------------===// +// NamedAttrList +//===----------------------------------------------------------------------===// + +NamedAttrList::NamedAttrList(ArrayRef attributes) { + assign(attributes.begin(), attributes.end()); +} + +NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) { + assign(in_start, in_end); +} + +ArrayRef NamedAttrList::getAttrs() const { return attrs; } + +DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { + if (!isSorted()) { + DictionaryAttr::sortInPlace(attrs); + dictionarySorted.setPointerAndInt(nullptr, true); + } + if (!dictionarySorted.getPointer()) + dictionarySorted.setPointer(DictionaryAttr::getWithSorted(attrs, context)); + return dictionarySorted.getPointer().cast(); +} + +NamedAttrList::operator MutableDictionaryAttr() const { + if (attrs.empty()) + return MutableDictionaryAttr(); + return getDictionary(attrs.front().second.getContext()); +} + +/// Add an attribute, uniquing `name` in the attribute's context. +void NamedAttrList::append(StringRef name, Attribute attr) { + append(Identifier::get(name, attr.getContext()), attr); +} + +/// Add an attribute with the specified name.
+void NamedAttrList::append(Identifier name, Attribute attr) { + push_back({name, attr}); +} + +/// Add an array of named attributes. +void NamedAttrList::append(ArrayRef newAttributes) { + append(newAttributes.begin(), newAttributes.end()); +} + +/// Add a range of named attributes. +void NamedAttrList::append(const_iterator in_start, const_iterator in_end) { + // TODO: expand to handle case where values appended are in order & after + // end of current list. + dictionarySorted.setPointerAndInt(nullptr, false); + attrs.append(in_start, in_end); +} + +/// Replaces the attributes with new list of attributes. +void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) { + DictionaryAttr::sort(ArrayRef{in_start, in_end}, attrs); + dictionarySorted.setPointerAndInt(nullptr, true); +} + +void NamedAttrList::push_back(NamedAttribute newAttribute) { + if (isSorted()) + dictionarySorted.setInt( + attrs.empty() || + strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0); + dictionarySorted.setPointer(nullptr); + attrs.push_back(newAttribute); +} + +/// Helper function to find attribute in possibly sorted vector of +/// NamedAttributes; returns attrs.end() if `name` is absent. +template +static auto *findAttr(SmallVectorImpl &attrs, T name, + bool sorted) { + if (!sorted) { + return llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + } + + auto *it = llvm::lower_bound(attrs, name); + if (it == attrs.end() || it->first != name) + return attrs.end(); + return it; +} + +/// Return the specified attribute if present, null otherwise. +Attribute NamedAttrList::get(StringRef name) const { + auto *it = findAttr(attrs, name, isSorted()); + return it != attrs.end() ? it->second : nullptr; +} + +/// Return the specified attribute if present, null otherwise. +Attribute NamedAttrList::get(Identifier name) const { + auto *it = findAttr(attrs, name, isSorted()); + return it != attrs.end() ? it->second : nullptr; +} + +/// Return the specified named attribute if present, None otherwise.
+Optional NamedAttrList::getNamed(StringRef name) const { + auto *it = findAttr(attrs, name, isSorted()); + return it != attrs.end() ? *it : Optional(); +} +Optional NamedAttrList::getNamed(Identifier name) const { + auto *it = findAttr(attrs, name, isSorted()); + return it != attrs.end() ? *it : Optional(); +} + +/// If an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. +void NamedAttrList::set(Identifier name, Attribute value) { + assert(value && "attributes may never be null"); + + // Look for an existing value for the given name, and set it in-place. + auto *it = findAttr(attrs, name, isSorted()); + if (it != attrs.end()) { + // Bail out early if the value is the same as what we already have. + if (it->second == value) + return; + dictionarySorted.setPointer(nullptr); + it->second = value; + return; + } + + // Otherwise, insert at the lower_bound position (sorted if list is sorted). + it = llvm::lower_bound(attrs, name); + dictionarySorted.setPointer(nullptr); + attrs.insert(it, {name, value}); +} +void NamedAttrList::set(StringRef name, Attribute value) { + assert(value && "setting null attribute not supported"); + return set(mlir::Identifier::get(name, value.getContext()), value); +} + +NamedAttrList & +NamedAttrList::operator=(const SmallVectorImpl &rhs) { + assign(rhs.begin(), rhs.end()); + return *this; +} + +NamedAttrList::operator ArrayRef() const { return attrs; } + //===----------------------------------------------------------------------===// // OperationState //===----------------------------------------------------------------------===// @@ -133,7 +273,7 @@ // Shift all operands down if the operand to remove is not at the end.
if (start != storage.numOperands) { - auto indexIt = std::next(operands.begin(), start); + auto *indexIt = std::next(operands.begin(), start); std::rotate(indexIt, std::next(indexIt, length), operands.end()); } for (unsigned i = 0; i != length; ++i) @@ -416,8 +556,8 @@ // Hash operations based upon their: // - Operation Name // - Attributes - llvm::hash_code hash = llvm::hash_combine( - op->getName(), op->getMutableAttrDict().getDictionary()); + llvm::hash_code hash = + llvm::hash_combine(op->getName(), op->getMutableAttrDict()); // - Result Types ArrayRef resultTypes = op->getResultTypes(); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -385,9 +385,9 @@ Operation *op, function_ref)> callback) { // Check to see if the operation has any attributes. - DictionaryAttr attrDict = op->getMutableAttrDict().getDictionary(); - if (!attrDict) + if (op->getMutableAttrDict().empty()) return WalkResult::advance(); + DictionaryAttr attrDict = op->getAttrDictionary(); // A worklist of a container attribute and the current index into the held // attribute list. @@ -803,7 +803,7 @@ // Generate a new attribute dictionary for the current operation by replacing // references to the old symbol.
auto generateNewAttrDict = [&] { - auto oldDict = curOp->getMutableAttrDict().getDictionary(); + auto oldDict = curOp->getAttrDictionary(); auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); return newDict.cast(); }; diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -24,11 +24,11 @@ LogicalResult mlir::detail::inferReturnTensorTypes( function_ref location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { SmallVector retComponents; if (failed(componentTypeFn(context, location, operands, attributes, regions, @@ -49,9 +49,9 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { SmallVector inferredReturnTypes; auto retTypeFn = cast(op); - if (failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(), - op->getOperands(), op->getAttrs(), - op->getRegions(), inferredReturnTypes))) + if (failed(retTypeFn.inferReturnTypes( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) return failure(); if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, op->getResultTypes())) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -275,7 +275,7 @@ Attribute parseAttribute(Type type = {}); /// Parse an attribute dictionary. - ParseResult parseAttributeDict(SmallVectorImpl &attributes); + ParseResult parseAttributeDict(NamedAttrList &attributes); /// Parse an extended attribute.
Attribute parseExtendedAttr(Type type); @@ -1569,10 +1569,10 @@ // Parse a dictionary attribute. case Token::l_brace: { - SmallVector elements; + NamedAttrList elements; if (parseAttributeDict(elements)) return nullptr; - return builder.getDictionaryAttr(elements); + return elements.getDictionary(getContext()); } // Parse an extended attribute, i.e. alias or dialect attribute. @@ -1671,8 +1671,7 @@ /// | `{` attribute-entry (`,` attribute-entry)* `}` /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value /// -ParseResult -Parser::parseAttributeDict(SmallVectorImpl &attributes) { +ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) return failure(); @@ -1701,7 +1700,6 @@ auto attr = parseAttribute(); if (!attr) return failure(); - attributes.push_back({*nameId, attr}); return success(); }; @@ -4217,7 +4215,7 @@ /// also adds the attribute to the specified attribute list with the specified /// name. ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) override { + NamedAttrList &attrs) override { result = parser.parseAttribute(type); if (!result) return failure(); @@ -4227,8 +4225,7 @@ } /// Parse a named dictionary into 'result' if it is present. - ParseResult - parseOptionalAttrDict(SmallVectorImpl &result) override { + ParseResult parseOptionalAttrDict(NamedAttrList &result) override { if (parser.getToken().isNot(Token::l_brace)) return success(); return parser.parseAttributeDict(result); @@ -4236,8 +4233,7 @@ /// Parse a named dictionary into 'result' if the `attributes` keyword is /// present.
- ParseResult parseOptionalAttrDictWithKeyword( - SmallVectorImpl &result) override { + ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override { if (failed(parseOptionalKeyword("attributes"))) return success(); return parser.parseAttributeDict(result); @@ -4285,9 +4281,8 @@ /// Parse an optional @-identifier and store it (without the '@' symbol) in a /// string attribute named 'attrName'. - ParseResult - parseOptionalSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) override { + ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, + NamedAttrList &attrs) override { Token atToken = parser.getToken(); if (atToken.isNot(Token::at_identifier)) return failure(); @@ -4435,7 +4430,7 @@ /// Parse an AffineMap of SSA ids. ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &mapAttr, StringRef attrName, - SmallVectorImpl &attrs, + NamedAttrList &attrs, Delimiter delimiter) override { SmallVector dimOperands; SmallVector symOperands; diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -33,9 +33,7 @@ // - Operation pointer addDataToHash(hasher, op); // - Attributes - addDataToHash( - hasher, - op->getMutableAttrDict().getDictionary().getAsOpaquePointer()); + addDataToHash(hasher, op->getMutableAttrDict()); // - Blocks in Regions for (Region ®ion : op->getRegions()) { for (Block &block : region) { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -325,7 +325,7 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() !=
operands[1].getType()) { return emitOptionalError(location, "operand type mismatch ", @@ -338,7 +338,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand. auto operandType = *operands.getTypes().begin(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -72,9 +72,9 @@ for (int j = 0; j < e; ++j) { std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; - if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, - op->getAttrs(), op->getRegions(), - inferredReturnTypes))) { + if (succeeded(OpTy::inferReturnTypes( + context, llvm::None, values, op->getAttrDictionary(), + op->getRegions(), inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); // TODO(jpienaar): Expand to regions.
OpTy::build(b, state, values, op->getAttrs()); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -775,7 +775,8 @@ body << formatv(R"( SmallVector inferredReturnTypes; if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, {1}.operands, {1}.attributes, + {1}.location, {1}.operands, + {1}.attributes.getDictionary({1}.getContext()), /*regions=*/{{}, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else @@ -867,13 +868,48 @@ opClass.newMethod("void", "build", formatv(params, builderOpState).str(), OpMethod::MP_Static); auto &body = m.body(); + + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariableLengthResults(); + int numNonVariadicResults = numResults - numVariadicResults; + + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); + int numNonVariadicOperands = numOperands - numVariadicOperands; + + // Operands: emit a count assert unless all operands are variadic. + if (numVariadicOperands == 0 || numNonVariadicOperands != 0) + body << " assert(operands.size()" + << (numVariadicOperands != 0 ? " >= " : " == ") + << numNonVariadicOperands + << "u && \"mismatched number of parameters\");\n"; + body << " " << builderOpState << ".addOperands(operands);\n"; + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ?
"numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Result types: infer them; failure to infer is a fatal error. body << formatv(R"( SmallVector inferredReturnTypes; if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, operands, attributes, - /*regions=*/{{}, inferredReturnTypes))) - build(odsBuilder, odsState, inferredReturnTypes, operands, attributes); - else + {1}.location, operands, + {1}.attributes.getDictionary({1}.getContext()), + /*regions=*/{{}, inferredReturnTypes))) {{)", + opClass.getClassName(), builderOpState); + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(inferredReturnTypes.size()" + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(inferredReturnTypes);"; + + body << formatv(R"( + } else llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -372,7 +372,7 @@ const char *const enumAttrParserCode = R"( { StringAttr attrVal; - SmallVector attrStorage; + NamedAttrList attrStorage; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), "{0}", attrStorage))