diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -308,9 +308,16 @@ size_t size() const; /// Sorts the NamedAttributes in the array ordered by name as expected by - /// getWithSorted. + /// getWithSorted and returns whether the values needed to be sorted. /// Requires: uniquely named attributes. - static void sort(SmallVectorImpl &array); + static bool sort(ArrayRef values, + SmallVectorImpl &storage); + + /// Sorts the NamedAttributes in the array ordered by name as expected by + /// getWithSorted in place on an array and returns whether the values needed + /// to be sorted. + /// Requires: uniquely named attributes. + static bool sortInPlace(SmallVectorImpl &array); /// Methods for supporting type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { @@ -1504,9 +1511,8 @@ return attrs == other.attrs; } - /// Return the underlying dictionary attribute. This may be null, if this list - /// has no attributes. - DictionaryAttr getDictionary() const { return attrs; } + /// Return the underlying dictionary attribute. + DictionaryAttr getDictionary(MLIRContext *context) const; /// Return all of the attributes on this operation. ArrayRef getAttrs() const; @@ -1532,6 +1538,8 @@ /// value indicates whether the attribute was present or not. RemoveResult remove(Identifier name); + bool empty() const { return attrs == nullptr; } + private: DictionaryAttr attrs; }; diff --git a/mlir/include/mlir/IR/FunctionImplementation.h b/mlir/include/mlir/IR/FunctionImplementation.h --- a/mlir/include/mlir/IR/FunctionImplementation.h +++ b/mlir/include/mlir/IR/FunctionImplementation.h @@ -38,8 +38,8 @@ /// Internally, argument and result attributes are stored as dict attributes /// with special names given by getResultAttrName, getArgumentAttrName. 
void addArgAndResultAttrs(Builder &builder, OperationState &result, - ArrayRef> argAttrs, - ArrayRef> resultAttrs); + ArrayRef argAttrs, + ArrayRef resultAttrs); /// Callback type for `parseFunctionLikeOp`, the callback should produce the /// type that will be associated with a function-like operation from lists of @@ -53,13 +53,13 @@ /// indicates whether functions with variadic arguments are supported. The /// trailing arguments are populated by this function with names, types and /// attributes of the arguments and those of the results. -ParseResult parseFunctionSignature( - OpAsmParser &parser, bool allowVariadic, - SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs); +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &argNames, + SmallVectorImpl &argTypes, + SmallVectorImpl &argAttrs, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); /// Parser implementation for function-like operations. 
Uses /// `funcTypeBuilder` to construct the custom function type given lists of diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -517,10 +517,15 @@ MutableDictionaryAttr attributes) { assert(index < getNumArguments() && "invalid argument number"); SmallString<8> nameOut; - if (auto newAttr = attributes.getDictionary()) + if (attributes.getAttrs().empty()) { + static_cast(this)->removeAttr( + getArgAttrName(index, nameOut)); + } else { + auto newAttr = attributes.getDictionary( + attributes.getAttrs().front().second.getContext()); return this->getOperation()->setAttr(getArgAttrName(index, nameOut), newAttr); - static_cast(this)->removeAttr(getArgAttrName(index, nameOut)); + } } /// If the an attribute exists with the specified name, change it to the new @@ -533,7 +538,7 @@ attrDict.set(name, value); // If the attribute changed, then set the new arg attribute list. - if (curAttr != attrDict.getDictionary()) + if (curAttr != attrDict.getDictionary(value.getContext())) setArgAttrs(index, attrDict); } @@ -574,11 +579,15 @@ unsigned index, MutableDictionaryAttr attributes) { assert(index < getNumResults() && "invalid result number"); SmallString<8> nameOut; - if (auto newAttr = attributes.getDictionary()) + if (attributes.getAttrs().empty()) { + static_cast(this)->removeAttr( + getResultAttrName(index, nameOut)); + } else { + auto newAttr = attributes.getDictionary( + attributes.getAttrs().front().second.getContext()); return this->getOperation()->setAttr(getResultAttrName(index, nameOut), newAttr); - static_cast(this)->removeAttr( - getResultAttrName(index, nameOut)); + } } /// If the an attribute exists with the specified name, change it to the new @@ -591,7 +600,7 @@ attrDict.set(name, value); // If the attribute changed, then set the new arg attribute list. 
- if (curAttr != attrDict.getDictionary()) + if (curAttr != attrDict.getDictionary(value.getContext())) setResultAttrs(index, attrDict); } diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -366,28 +366,28 @@ /// Parse an arbitrary attribute and return it in result. This also adds the /// attribute to the specified attribute list with the specified name. ParseResult parseAttribute(Attribute &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { return parseAttribute(result, Type(), attrName, attrs); } /// Parse an attribute of a specific kind and type. template ParseResult parseAttribute(AttrType &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { return parseAttribute(result, Type(), attrName, attrs); } /// Parse an arbitrary attribute of a given type and return it in result. This /// also adds the attribute to the specified attribute list with the specified /// name. - virtual ParseResult - parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult parseAttribute(Attribute &result, Type type, + StringRef attrName, + NamedAttrList &attrs) = 0; /// Parse an attribute of a specific kind and type. template ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { llvm::SMLoc loc = getCurrentLocation(); // Parse any kind of attribute. @@ -404,13 +404,12 @@ } /// Parse a named dictionary into 'result' if it is present. - virtual ParseResult - parseOptionalAttrDict(SmallVectorImpl &result) = 0; + virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; /// Parse a named dictionary into 'result' if the `attributes` keyword is /// present. 
virtual ParseResult - parseOptionalAttrDictWithKeyword(SmallVectorImpl &result) = 0; + parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; /// Parse an affine map instance into 'map'. virtual ParseResult parseAffineMap(AffineMap &map) = 0; @@ -425,7 +424,7 @@ /// Parse an @-identifier and store it (without the '@' symbol) in a string /// attribute named 'attrName'. ParseResult parseSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) { + NamedAttrList &attrs) { if (failed(parseOptionalSymbolName(result, attrName, attrs))) return emitError(getCurrentLocation()) << "expected valid '@'-identifier for symbol name"; @@ -434,9 +433,9 @@ /// Parse an optional @-identifier and store it (without the '@' symbol) in a /// string attribute named 'attrName'. - virtual ParseResult - parseOptionalSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) = 0; + virtual ParseResult parseOptionalSymbolName(StringAttr &result, + StringRef attrName, + NamedAttrList &attrs) = 0; //===--------------------------------------------------------------------===// // Operand Parsing @@ -552,8 +551,7 @@ /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. virtual ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &map, - StringRef attrName, - SmallVectorImpl &attrs, + StringRef attrName, NamedAttrList &attrs, Delimiter delimiter = Delimiter::Square) = 0; //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -294,6 +294,11 @@ /// Return all of the attributes on this operation. ArrayRef getAttrs() { return attrs.getAttrs(); } + /// Return all of the attributes on this operation as a DictionaryAttr. 
+ DictionaryAttr getAttrDictionary() { + return attrs.getDictionary(getContext()); + } + /// Return mutable container of all the attributes on this operation. MutableDictionaryAttr &getMutableAttrDict() { return attrs; } diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -200,6 +200,97 @@ bool (&hasRawTrait)(TypeID traitID); }; +//===----------------------------------------------------------------------===// +// NamedAttrList +//===----------------------------------------------------------------------===// + +/// NamedAttrList is array of NamedAttributes that tracks whether it is sorted +/// and does some basic work to remain sorted. +class NamedAttrList { +public: + using const_iterator = SmallVectorImpl::const_iterator; + using const_reference = const NamedAttribute &; + using reference = NamedAttribute &; + using size_type = size_t; + + NamedAttrList() : dictionarySorted({}, true) {} + NamedAttrList(ArrayRef attributes); + NamedAttrList(const_iterator in_start, const_iterator in_end); + + bool operator!=(const NamedAttrList &other) const { + return !(*this == other); + } + bool operator==(const NamedAttrList &other) const { + return attrs == other.attrs; + } + + /// Add an attribute with the specified name. + void append(StringRef name, Attribute attr); + + /// Add an attribute with the specified name. + void append(Identifier name, Attribute attr); + + /// Add an array of named attributes. + void append(ArrayRef newAttributes); + + /// Add a range of named attributes. + void append(const_iterator in_start, const_iterator in_end); + + /// Replaces the attributes with new list of attributes. + void assign(const_iterator in_start, const_iterator in_end); + + /// Replaces the attributes with new list of attributes. 
+ void assign(ArrayRef range) { + assign(range.begin(), range.end()); + } + + bool empty() const { return attrs.empty(); } + + void reserve(size_type N) { attrs.reserve(N); } + + /// Add an attribute with the specified name. + void push_back(NamedAttribute newAttribute); + + /// Pop last element from list. + void pop_back() { attrs.pop_back(); } + + /// Return a dictionary attribute for the underlying dictionary. This will + /// return an empty dictionary attribute if empty rather than null. + DictionaryAttr getDictionary(MLIRContext *context) const; + + /// Return all of the attributes on this operation. + ArrayRef getAttrs() const; + + /// Return the specified attribute if present, null otherwise. + Attribute get(Identifier name) const; + Attribute get(StringRef name) const; + + /// Return the specified named attribute if present, None otherwise. + Optional getNamed(StringRef name) const; + Optional getNamed(Identifier name) const; + + /// If an attribute exists with the specified name, change it to the new + /// value. Otherwise, add a new attribute with the specified name/value. + void set(Identifier name, Attribute value); + void set(StringRef name, Attribute value); + + const_iterator begin() const { return attrs.begin(); } + const_iterator end() const { return attrs.end(); } + + NamedAttrList &operator=(const SmallVectorImpl &rhs); + operator ArrayRef() const; + operator MutableDictionaryAttr() const; + +private: + // These are marked mutable as they may be modified (e.g., sorted) + mutable SmallVector attrs; + // Pair with cached DictionaryAttr and status of whether attrs is sorted. + // Note: just because sorted does not mean a DictionaryAttr has been created + // but the case where there is a DictionaryAttr but attrs isn't sorted should
+ mutable llvm::PointerIntPair dictionarySorted; +}; + //===----------------------------------------------------------------------===// // OperationName //===----------------------------------------------------------------------===// @@ -268,7 +359,7 @@ SmallVector operands; /// Types of the results of this operation. SmallVector types; - SmallVector attributes; + NamedAttrList attributes; /// Successors of this operation and their respective operands. SmallVector successors; /// Regions that the op will hold. @@ -302,12 +393,12 @@ /// Add an attribute with the specified name. void addAttribute(Identifier name, Attribute attr) { - attributes.push_back({name, attr}); + attributes.append(name, attr); } /// Add an array of named attributes. void addAttributes(ArrayRef newAttributes) { - attributes.append(newAttributes.begin(), newAttributes.end()); + attributes.append(newAttributes); } /// Add an array of successors. @@ -328,7 +419,7 @@ void addRegion(std::unique_ptr &®ion); /// Get the context held by this operation state. 
- MLIRContext *getContext() { return location->getContext(); } + MLIRContext *getContext() const { return location->getContext(); } }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.h +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.h @@ -83,11 +83,11 @@ LogicalResult inferReturnTensorTypes( function_ref location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes); /// Verifies that the inferred result types match the actual result types for @@ -107,7 +107,7 @@ public: static LogicalResult inferReturnTypes(MLIRContext *context, Optional location, - ValueRange operands, ArrayRef attributes, + ValueRange operands, DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { return ::mlir::detail::inferReturnTensorTypes( diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td --- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td +++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td @@ -40,7 +40,7 @@ /*args=*/(ins "MLIRContext*":$context, "Optional":$location, "ValueRange":$operands, - "ArrayRef":$attributes, + "DictionaryAttr":$attributes, "RegionRange":$regions, "SmallVectorImpl&":$inferredReturnTypes) >, @@ -92,7 +92,7 @@ /*args=*/(ins "MLIRContext*":$context, "Optional":$location, "ValueRange":$operands, - "ArrayRef":$attributes, + "DictionaryAttr":$attributes, "RegionRange":$regions, "SmallVectorImpl&": $inferredReturnShapes) diff --git 
a/mlir/lib/Analysis/CallGraph.cpp b/mlir/lib/Analysis/CallGraph.cpp --- a/mlir/lib/Analysis/CallGraph.cpp +++ b/mlir/lib/Analysis/CallGraph.cpp @@ -179,7 +179,8 @@ auto *parentOp = callableRegion->getParentOp(); os << "'" << callableRegion->getParentOp()->getName() << "' - Region #" << callableRegion->getRegionNumber(); - if (auto attrs = parentOp->getMutableAttrDict().getDictionary()) + auto attrs = parentOp->getAttrDictionary(); + if (!attrs.empty()) os << " : " << attrs; }; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -2453,7 +2453,7 @@ return failure(); AffineMapAttr stepsMapAttr; - SmallVector stepsAttrs; + NamedAttrList stepsAttrs; SmallVector stepsMapOperands; if (failed(parser.parseOptionalKeyword("step"))) { SmallVector steps(ivs.size(), 1); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -538,8 +538,8 @@ /// function-attributes? region static ParseResult parseGPUFuncOp(OpAsmParser &parser, OperationState &result) { SmallVector entryArgs; - SmallVector, 1> argAttrs; - SmallVector, 1> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; bool isVariadic; diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -92,8 +92,8 @@ predicateValue = static_cast(predicate.getValue()); } - result.attributes[0].second = - parser.getBuilder().getI64IntegerAttr(predicateValue); + result.attributes.set("predicate", + parser.getBuilder().getI64IntegerAttr(predicateValue)); // The result type is either i1 or a vector type if the inputs are // vectors. 
@@ -1186,7 +1186,7 @@ "expected as many argument attribute lists as arguments"); SmallString<8> argAttrName; for (unsigned i = 0; i < numInputs; ++i) - if (auto argDict = argAttrs[i].getDictionary()) + if (auto argDict = argAttrs[i].getDictionary(builder.getContext())) result.addAttribute(getArgAttrName(i, argAttrName), argDict); } @@ -1249,8 +1249,8 @@ StringAttr nameAttr; SmallVector entryArgs; - SmallVector, 1> argAttrs; - SmallVector, 1> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; bool isVariadic; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -353,8 +353,8 @@ } bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { - return lhs.getOperation()->getMutableAttrDict().getDictionary() == - rhs.getOperation()->getMutableAttrDict().getDictionary(); + return lhs.getOperation()->getAttrDictionary() == + rhs.getOperation()->getAttrDictionary(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -105,7 +105,7 @@ parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName = spirv::attributeName()) { Attribute attrVal; - SmallVector attr; + NamedAttrList attr; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) { @@ -1019,7 +1019,7 @@ // Parse the optional branch weights. 
if (succeeded(parser.parseOptionalLSquare())) { IntegerAttr trueWeight, falseWeight; - SmallVector weights; + NamedAttrList weights; auto i32Type = builder.getIntegerType(32); if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || @@ -1443,7 +1443,7 @@ // The name of the interface variable attribute isnt important auto attrName = "var_symbol"; FlatSymbolRefAttr var; - SmallVector attrs; + NamedAttrList attrs; if (parser.parseAttribute(var, Type(), attrName, attrs)) { return failure(); } @@ -1497,7 +1497,7 @@ SmallVector values; Type i32Type = parser.getBuilder().getIntegerType(32); while (!parser.parseOptionalComma()) { - SmallVector attr; + NamedAttrList attr; Attribute value; if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); @@ -1529,8 +1529,8 @@ static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) { SmallVector entryArgs; - SmallVector, 4> argAttrs; - SmallVector, 4> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; auto &builder = parser.getBuilder(); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -92,10 +92,11 @@ // BroadcastOp //===----------------------------------------------------------------------===// -LogicalResult BroadcastOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +BroadcastOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -137,7 +138,7 @@ // shape as an ArrayAttr. // TODO: Implement custom parser and maybe make syntax a bit more concise. 
Attribute extentsRaw; - SmallVector dummy; + NamedAttrList dummy; if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); auto extentsArray = extentsRaw.dyn_cast(); @@ -159,10 +160,11 @@ OpFoldResult ConstShapeOp::fold(ArrayRef) { return shape(); } -LogicalResult ConstShapeOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConstShapeOp::inferReturnTypes(MLIRContext *context, + Optional location, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -171,10 +173,11 @@ // ConstSizeOp //===----------------------------------------------------------------------===// -LogicalResult ConstSizeOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConstSizeOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(SizeType::get(context)); return success(); } @@ -183,10 +186,11 @@ // ShapeOfOp //===----------------------------------------------------------------------===// -LogicalResult ShapeOfOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ShapeOfOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { inferredReturnTypes.push_back(ShapeType::get(context)); return success(); } @@ -203,10 +207,11 @@ // SplitAtOp 
//===----------------------------------------------------------------------===// -LogicalResult SplitAtOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +SplitAtOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); inferredReturnTypes.push_back(shapeType); @@ -238,10 +243,11 @@ // ConcatOp //===----------------------------------------------------------------------===// -LogicalResult ConcatOp::inferReturnTypes( - MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { +LogicalResult +ConcatOp::inferReturnTypes(MLIRContext *context, Optional location, + ValueRange operands, DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { auto shapeType = ShapeType::get(context); inferredReturnTypes.push_back(shapeType); return success(); diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -478,7 +478,7 @@ static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result) { llvm::SMLoc attributeLoc, typeLoc; - SmallVector attrs; + NamedAttrList attrs; OpAsmParser::OperandType vector; Type type; Attribute attr; diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -111,21 +111,31 @@ // Specialize for the common case. switch (value.size()) { case 0: + // Zero already sorted. + break; case 1: - // Zero or one elements are already sorted. + // One already sorted but may need to be copied. 
+ if (!inPlace) + storage.assign({value[0]}); break; - case 2: + case 2: { assert(value[0].first != value[1].first && "DictionaryAttr element names must be unique"); - if (compareNamedAttributes(&value[0], &value[1]) > 0) { - if (inPlace) + bool isSorted = compareNamedAttributes(&value[0], &value[1]) < 0; + if (inPlace) { + if (!isSorted) std::swap(storage[0], storage[1]); + } else { + if (isSorted) + storage.assign({value[0], value[1]}); else - storage.append({value[1], value[0]}); - return true; + storage.assign({value[1], value[0]}); } - break; + return !isSorted; + } default: + if (!inPlace) + storage.assign(value.begin(), value.end()); // Check to see they are sorted already. bool isSorted = llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) { @@ -133,8 +143,6 @@ }); if (!isSorted) { // If not, do a general sort. - if (!inPlace) - storage.append(value.begin(), value.end()); llvm::array_pod_sort(storage.begin(), storage.end(), compareNamedAttributes); value = storage; @@ -151,31 +159,21 @@ return false; } -/// Sorts the NamedAttributes in the array ordered by name as expected by -/// getWithSorted. -/// Requires: uniquely named attributes. -void DictionaryAttr::sort(SmallVectorImpl &array) { - dictionaryAttrSort(array, array); +bool DictionaryAttr::sort(ArrayRef value, + SmallVectorImpl &storage) { + return dictionaryAttrSort(value, storage); } -DictionaryAttr DictionaryAttr::get(ArrayRef value, - MLIRContext *context) { - assert(llvm::all_of(value, - [](const NamedAttribute &attr) { return attr.second; }) && - "value cannot have null entries"); - - // We need to sort the element list to canonicalize it. 
- SmallVector storage; - if (dictionaryAttrSort(value, storage)) - value = storage; - - return Base::get(context, StandardAttributes::Dictionary, value); +bool DictionaryAttr::sortInPlace(SmallVectorImpl &array) { + return dictionaryAttrSort(array, array); } /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. DictionaryAttr DictionaryAttr::getWithSorted(ArrayRef value, MLIRContext *context) { + if (value.empty()) + return get(value, context); // Ensure that the attribute elements are unique and sorted. assert(llvm::is_sorted(value, [](NamedAttribute l, NamedAttribute r) { @@ -207,7 +205,8 @@ /// Return the specified named attribute if present, None otherwise. Optional DictionaryAttr::getNamed(StringRef name) const { ArrayRef values = getValue(); - auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); + const auto *it = + llvm::lower_bound(values, name, compareNamedAttributeWithName); return it != values.end() && it->first == name ? *it : Optional(); } @@ -1193,6 +1192,15 @@ setAttrs(attributes); } +/// Return the underlying dictionary attribute. +DictionaryAttr +MutableDictionaryAttr::getDictionary(MLIRContext *context) const { + // Construct empty DictionaryAttr if needed. + if (!attrs) + return DictionaryAttr::get({}, context); + return attrs; +} + ArrayRef MutableDictionaryAttr::getAttrs() const { return attrs ? attrs.getValue() : llvm::None; } @@ -1232,7 +1240,7 @@ // Look for an existing value for the given name, and set it in-place. ArrayRef values = getAttrs(); - auto it = llvm::find_if( + const auto *it = llvm::find_if( values, [name](NamedAttribute attr) { return attr.first == name; }); if (it != values.end()) { // Bail out early if the value is the same as what we already have. 
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp --- a/mlir/lib/IR/Function.cpp +++ b/mlir/lib/IR/Function.cpp @@ -57,7 +57,7 @@ assert(type.getNumInputs() == argAttrs.size()); SmallString<8> argAttrName; for (unsigned i = 0, e = type.getNumInputs(); i != e; ++i) - if (auto argDict = argAttrs[i].getDictionary()) + if (auto argDict = argAttrs[i].getDictionary(builder.getContext())) result.addAttribute(getArgAttrName(i, argAttrName), argDict); } diff --git a/mlir/lib/IR/FunctionImplementation.cpp b/mlir/lib/IR/FunctionImplementation.cpp --- a/mlir/lib/IR/FunctionImplementation.cpp +++ b/mlir/lib/IR/FunctionImplementation.cpp @@ -17,8 +17,7 @@ parseArgumentList(OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argTypes, SmallVectorImpl &argNames, - SmallVectorImpl> &argAttrs, - bool &isVariadic) { + SmallVectorImpl &argAttrs, bool &isVariadic) { if (parser.parseLParen()) return failure(); @@ -54,7 +53,7 @@ argTypes.push_back(argumentType); // Parse any argument attributes. - SmallVector attrs; + NamedAttrList attrs; if (parser.parseOptionalAttrDict(attrs)) return failure(); argAttrs.push_back(attrs); @@ -90,9 +89,9 @@ /// function-result-list-no-parens ::= function-result (`,` function-result)* /// function-result ::= type attribute-dict? /// -static ParseResult parseFunctionResultList( - OpAsmParser &parser, SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs) { +static ParseResult +parseFunctionResultList(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { if (failed(parser.parseOptionalLParen())) { // We already know that there is no `(`, so parse a type. // Because there is no `(`, it cannot be a function type. 
@@ -127,10 +126,9 @@ ParseResult mlir::impl::parseFunctionSignature( OpAsmParser &parser, bool allowVariadic, SmallVectorImpl &argNames, - SmallVectorImpl &argTypes, - SmallVectorImpl> &argAttrs, bool &isVariadic, - SmallVectorImpl &resultTypes, - SmallVectorImpl> &resultAttrs) { + SmallVectorImpl &argTypes, SmallVectorImpl &argAttrs, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { if (parseArgumentList(parser, allowVariadic, argTypes, argNames, argAttrs, isVariadic)) return failure(); @@ -139,10 +137,9 @@ return success(); } -void mlir::impl::addArgAndResultAttrs( - Builder &builder, OperationState &result, - ArrayRef> argAttrs, - ArrayRef> resultAttrs) { +void mlir::impl::addArgAndResultAttrs(Builder &builder, OperationState &result, + ArrayRef argAttrs, + ArrayRef resultAttrs) { // Add the attributes to the function arguments. SmallString<8> attrNameBuf; for (unsigned i = 0, e = argAttrs.size(); i != e; ++i) @@ -164,8 +161,8 @@ bool allowVariadic, mlir::impl::FuncTypeBuilder funcTypeBuilder) { SmallVector entryArgs; - SmallVector, 4> argAttrs; - SmallVector, 4> resultAttrs; + SmallVector argAttrs; + SmallVector resultAttrs; SmallVector argTypes; SmallVector resultTypes; auto &builder = parser.getBuilder(); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -328,6 +328,7 @@ BoolAttr falseAttr, trueAttr; UnitAttr unitAttr; UnknownLoc unknownLocAttr; + DictionaryAttr emptyDictionaryAttr; public: MLIRContextImpl() : identifiers(identifierAllocator) {} @@ -388,6 +389,9 @@ /// Unknown Location Attribute. impl->unknownLocAttr = AttributeUniquer::get( this, StandardAttributes::UnknownLocation); + /// The empty dictionary attribute. 
+ impl->emptyDictionaryAttr = AttributeUniquer::get( + this, StandardAttributes::Dictionary, ArrayRef{}); } MLIRContext::~MLIRContext() {} @@ -742,6 +746,22 @@ return context->getImpl().unknownLocAttr; } +DictionaryAttr DictionaryAttr::get(ArrayRef value, + MLIRContext *context) { + if (value.empty()) + return context->getImpl().emptyDictionaryAttr; + assert(llvm::all_of(value, + [](const NamedAttribute &attr) { return attr.second; }) && + "value cannot have null entries"); + + // We need to sort the element list to canonicalize it. + SmallVector storage; + if (DictionaryAttr::sort(value, storage)) + value = storage; + + return Base::get(context, StandardAttributes::Dictionary, value); +} + //===----------------------------------------------------------------------===// // AffineMap uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -70,9 +70,9 @@ /// Create a new Operation from operation state. Operation *Operation::create(const OperationState &state) { - return Operation::create( - state.location, state.name, state.types, state.operands, - MutableDictionaryAttr(state.attributes), state.successors, state.regions); + return Operation::create(state.location, state.name, state.types, + state.operands, state.attributes, state.successors, + state.regions); } /// Create a new Operation with the specific fields. 
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -18,6 +18,141 @@ #include "mlir/IR/StandardTypes.h" using namespace mlir; +//===----------------------------------------------------------------------===// +// NamedAttrList +//===----------------------------------------------------------------------===// + +NamedAttrList::NamedAttrList(ArrayRef attributes) { + assign(attributes.begin(), attributes.end()); +} + +NamedAttrList::NamedAttrList(const_iterator in_start, const_iterator in_end) { + assign(in_start, in_end); +} + +ArrayRef NamedAttrList::getAttrs() const { return attrs; } + +DictionaryAttr NamedAttrList::getDictionary(MLIRContext *context) const { + if (!dictionarySorted.getInt()) { + DictionaryAttr::sortInPlace(attrs); + dictionarySorted.setInt(true); + dictionarySorted.setPointer(nullptr); + } + if (!dictionarySorted.getPointer()) + dictionarySorted.setPointer(DictionaryAttr::getWithSorted(attrs, context)); + return dictionarySorted.getPointer().cast(); +} + +NamedAttrList::operator MutableDictionaryAttr() const { + if (attrs.empty()) + return MutableDictionaryAttr(); + return getDictionary(attrs.front().second.getContext()); +} + +/// Add an attribute with the specified name. +void NamedAttrList::append(StringRef name, Attribute attr) { + append(Identifier::get(name, attr.getContext()), attr); +} + +/// Add an attribute with the specified name. +void NamedAttrList::append(Identifier name, Attribute attr) { + push_back({name, attr}); +} + +/// Add an array of named attributes. +void NamedAttrList::append(ArrayRef newAttributes) { + append(newAttributes.begin(), newAttributes.end()); +} + +/// Add a range of named attributes. +void NamedAttrList::append(const_iterator in_start, const_iterator in_end) { + // TODO: expand to handle case where values appended are in order & after + // end of current list. 
+ dictionarySorted.setInt(false); + dictionarySorted.setPointer(nullptr); + attrs.append(in_start, in_end); +} + +/// Replaces the attributes with new list of attributes. +void NamedAttrList::assign(const_iterator in_start, const_iterator in_end) { + DictionaryAttr::sort(ArrayRef{in_start, in_end}, attrs); + dictionarySorted.setInt(true); + dictionarySorted.setPointer(nullptr); +} + +void NamedAttrList::push_back(NamedAttribute newAttribute) { + dictionarySorted.setInt( + dictionarySorted.getInt() && + (attrs.empty() || + strcmp(attrs.back().first.data(), newAttribute.first.data()) < 0)); + dictionarySorted.setPointer(nullptr); + attrs.push_back(newAttribute); +} + +/// Return the specified attribute if present, null otherwise. +Attribute NamedAttrList::get(StringRef name) const { + const auto *it = llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + return it != attrs.end() ? it->second : nullptr; +} + +/// Return the specified attribute if present, null otherwise. +Attribute NamedAttrList::get(Identifier name) const { + const auto *it = llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + return it != attrs.end() ? it->second : nullptr; +} + +/// Return the specified named attribute if present, None otherwise. +Optional NamedAttrList::getNamed(StringRef name) const { + const auto *it = llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + return it != attrs.end() ? *it : Optional(); +} +Optional NamedAttrList::getNamed(Identifier name) const { + const auto *it = llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + return it != attrs.end() ? *it : Optional(); +} + +/// If the an attribute exists with the specified name, change it to the new +/// value. Otherwise, add a new attribute with the specified name/value. 
+void NamedAttrList::set(Identifier name, Attribute value) { + assert(value && "attributes may never be null"); + + // Look for an existing value for the given name, and set it in-place. + auto *it = llvm::find_if( + attrs, [name](NamedAttribute attr) { return attr.first == name; }); + if (it != attrs.end()) { + // Bail out early if the value is the same as what we already have. + if (it->second == value) + return; + dictionarySorted.setPointer(nullptr); + it->second = value; + return; + } + + // Otherwise, insert the new attribute into its sorted position. + it = llvm::lower_bound( + attrs, name, [](const NamedAttribute &attr, StringRef name) { + return strncmp(attr.first.data(), name.data(), name.size()) < 0; + }); + dictionarySorted.setPointer(nullptr); + attrs.insert(it, {name, value}); +} +void NamedAttrList::set(StringRef name, Attribute value) { + assert(value && "setting null attribute not supported"); + return set(mlir::Identifier::get(name, value.getContext()), value); +} + +NamedAttrList & +NamedAttrList::operator=(const SmallVectorImpl &rhs) { + assign(rhs.begin(), rhs.end()); + return *this; +} + +NamedAttrList::operator ArrayRef() const { return attrs; } + //===----------------------------------------------------------------------===// // OperationState //===----------------------------------------------------------------------===// @@ -133,7 +268,7 @@ // Shift all operands down if the operand to remove is not at the end. if (start != storage.numOperands) { - auto indexIt = std::next(operands.begin(), start); + auto *indexIt = std::next(operands.begin(), start); std::rotate(indexIt, std::next(indexIt, length), operands.end()); } for (unsigned i = 0; i != length; ++i) @@ -417,7 +552,8 @@ // - Operation Name // - Attributes llvm::hash_code hash = llvm::hash_combine( - op->getName(), op->getMutableAttrDict().getDictionary()); + op->getName(), + op->getAttrs().empty() ? 
nullptr : op->getAttrDictionary()); // - Result Types ArrayRef resultTypes = op->getResultTypes(); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -385,9 +385,9 @@ Operation *op, function_ref)> callback) { // Check to see if the operation has any attributes. - DictionaryAttr attrDict = op->getMutableAttrDict().getDictionary(); - if (!attrDict) + if (op->getMutableAttrDict().empty()) return WalkResult::advance(); + DictionaryAttr attrDict = op->getAttrDictionary(); // A worklist of a container attribute and the current index into the held // attribute list. @@ -803,7 +803,7 @@ // Generate a new attribute dictionary for the current operation by replacing // references to the old symbol. auto generateNewAttrDict = [&] { - auto oldDict = curOp->getMutableAttrDict().getDictionary(); + auto oldDict = curOp->getAttrDictionary(); auto newDict = rebuildAttrAfterRAUW(oldDict, accessChains, /*depth=*/0); return newDict.cast(); }; diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp --- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp +++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp @@ -24,11 +24,11 @@ LogicalResult mlir::detail::inferReturnTensorTypes( function_ref location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &retComponents)> componentTypeFn, MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { SmallVector retComponents; if (failed(componentTypeFn(context, location, operands, attributes, regions, @@ -49,9 +49,9 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) { SmallVector inferredReturnTypes; auto retTypeFn = cast(op); - if 
(failed(retTypeFn.inferReturnTypes(op->getContext(), op->getLoc(), - op->getOperands(), op->getAttrs(), - op->getRegions(), inferredReturnTypes))) + if (failed(retTypeFn.inferReturnTypes( + op->getContext(), op->getLoc(), op->getOperands(), + op->getAttrDictionary(), op->getRegions(), inferredReturnTypes))) return failure(); if (!retTypeFn.isCompatibleReturnTypes(inferredReturnTypes, op->getResultTypes())) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -275,7 +275,7 @@ Attribute parseAttribute(Type type = {}); /// Parse an attribute dictionary. - ParseResult parseAttributeDict(SmallVectorImpl &attributes); + ParseResult parseAttributeDict(NamedAttrList &attributes); /// Parse an extended attribute. Attribute parseExtendedAttr(Type type); @@ -1569,10 +1569,10 @@ // Parse a dictionary attribute. case Token::l_brace: { - SmallVector elements; + NamedAttrList elements; if (parseAttributeDict(elements)) return nullptr; - return builder.getDictionaryAttr(elements); + return elements.getDictionary(getContext()); } // Parse an extended attribute, i.e. alias or dialect attribute. @@ -1671,8 +1671,7 @@ /// | `{` attribute-entry (`,` attribute-entry)* `}` /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value /// -ParseResult -Parser::parseAttributeDict(SmallVectorImpl &attributes) { +ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) { if (parseToken(Token::l_brace, "expected '{' in attribute dictionary")) return failure(); @@ -1701,7 +1700,6 @@ auto attr = parseAttribute(); if (!attr) return failure(); - attributes.push_back({*nameId, attr}); return success(); }; @@ -4217,7 +4215,7 @@ /// also adds the attribute to the specified attribute list with the specified /// name. 
ParseResult parseAttribute(Attribute &result, Type type, StringRef attrName, - SmallVectorImpl &attrs) override { + NamedAttrList &attrs) override { result = parser.parseAttribute(type); if (!result) return failure(); @@ -4227,8 +4225,7 @@ } /// Parse a named dictionary into 'result' if it is present. - ParseResult - parseOptionalAttrDict(SmallVectorImpl &result) override { + ParseResult parseOptionalAttrDict(NamedAttrList &result) override { if (parser.getToken().isNot(Token::l_brace)) return success(); return parser.parseAttributeDict(result); @@ -4236,8 +4233,7 @@ /// Parse a named dictionary into 'result' if the `attributes` keyword is /// present. - ParseResult parseOptionalAttrDictWithKeyword( - SmallVectorImpl &result) override { + ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result) override { if (failed(parseOptionalKeyword("attributes"))) return success(); return parser.parseAttributeDict(result); @@ -4285,9 +4281,8 @@ /// Parse an optional @-identifier and store it (without the '@' symbol) in a /// string attribute named 'attrName'. - ParseResult - parseOptionalSymbolName(StringAttr &result, StringRef attrName, - SmallVectorImpl &attrs) override { + ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, + NamedAttrList &attrs) override { Token atToken = parser.getToken(); if (atToken.isNot(Token::at_identifier)) return failure(); @@ -4435,7 +4430,7 @@ /// Parse an AffineMap of SSA ids. 
ParseResult parseAffineMapOfSSAIds(SmallVectorImpl &operands, Attribute &mapAttr, StringRef attrName, - SmallVectorImpl &attrs, + NamedAttrList &attrs, Delimiter delimiter) override { SmallVector dimOperands; SmallVector symOperands; diff --git a/mlir/lib/Pass/IRPrinting.cpp b/mlir/lib/Pass/IRPrinting.cpp --- a/mlir/lib/Pass/IRPrinting.cpp +++ b/mlir/lib/Pass/IRPrinting.cpp @@ -33,9 +33,9 @@ // - Operation pointer addDataToHash(hasher, op); // - Attributes - addDataToHash( - hasher, - op->getMutableAttrDict().getDictionary().getAsOpaquePointer()); + addDataToHash(hasher, op->getAttrs().empty() + ? nullptr + : op->getAttrDictionary().getAsOpaquePointer()); // - Blocks in Regions for (Region ®ion : op->getRegions()) { for (Block &block : region) { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -325,7 +325,7 @@ LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { if (operands[0].getType() != operands[1].getType()) { return emitOptionalError(location, "operand type mismatch ", @@ -338,7 +338,7 @@ LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( MLIRContext *context, Optional location, ValueRange operands, - ArrayRef attributes, RegionRange regions, + DictionaryAttr attributes, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // Create return type consisting of the last element of the first operand. 
auto operandType = *operands.getTypes().begin(); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -72,9 +72,9 @@ for (int j = 0; j < e; ++j) { std::array values = {{fop.getArgument(i), fop.getArgument(j)}}; SmallVector inferredReturnTypes; - if (succeeded(OpTy::inferReturnTypes(context, llvm::None, values, - op->getAttrs(), op->getRegions(), - inferredReturnTypes))) { + if (succeeded(OpTy::inferReturnTypes( + context, llvm::None, values, op->getAttrDictionary(), + op->getRegions(), inferredReturnTypes))) { OperationState state(location, OpTy::getOperationName()); // TODO(jpienaar): Expand to regions. OpTy::build(b, state, values, op->getAttrs()); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -775,7 +775,8 @@ body << formatv(R"( SmallVector inferredReturnTypes; if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, {1}.operands, {1}.attributes, + {1}.location, {1}.operands, + {1}.attributes.getDictionary({1}.getContext()), /*regions=*/{{}, inferredReturnTypes))) {1}.addTypes(inferredReturnTypes); else @@ -867,13 +868,48 @@ opClass.newMethod("void", "build", formatv(params, builderOpState).str(), OpMethod::MP_Static); auto &body = m.body(); + + int numResults = op.getNumResults(); + int numVariadicResults = op.getNumVariableLengthResults(); + int numNonVariadicResults = numResults - numVariadicResults; + + int numOperands = op.getNumOperands(); + int numVariadicOperands = op.getNumVariableLengthOperands(); + int numNonVariadicOperands = numOperands - numVariadicOperands; + + // Operands + if (numVariadicOperands == 0 || numNonVariadicOperands != 0) + body << " assert(operands.size()" + << (numVariadicOperands != 0 ? 
" >= " : " == ") + << numNonVariadicOperands + << "u && \"mismatched number of parameters\");\n"; + body << " " << builderOpState << ".addOperands(operands);\n"; + body << " " << builderOpState << ".addAttributes(attributes);\n"; + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; + } + + // Result types body << formatv(R"( SmallVector inferredReturnTypes; if (succeeded({0}::inferReturnTypes(odsBuilder.getContext(), - {1}.location, operands, attributes, - /*regions=*/{{}, inferredReturnTypes))) - build(odsBuilder, odsState, inferredReturnTypes, operands, attributes); - else + {1}.location, operands, + {1}.attributes.getDictionary({1}.getContext()), + /*regions=*/{{}, inferredReturnTypes))) {{)", + opClass.getClassName(), builderOpState); + if (numVariadicResults == 0 || numNonVariadicResults != 0) + body << " assert(inferredReturnTypes.size()" + << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults + << "u && \"mismatched number of return types\");\n"; + body << " " << builderOpState << ".addTypes(inferredReturnTypes);"; + + body << formatv(R"( + } else llvm::report_fatal_error("Failed to infer result type(s).");)", opClass.getClassName(), builderOpState); } diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -372,7 +372,7 @@ const char *const enumAttrParserCode = R"( { StringAttr attrVal; - SmallVector attrStorage; + NamedAttrList attrStorage; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), "{0}", attrStorage))