diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td
--- a/mlir/include/mlir/IR/BuiltinAttributes.td
+++ b/mlir/include/mlir/IR/BuiltinAttributes.td
@@ -389,6 +389,18 @@
     Optional<NamedAttribute> getNamed(StringRef name) const;
     Optional<NamedAttribute> getNamed(Identifier name) const;
 
+    /// Lookup an attribute on a subrange. This method is useful when the
+    /// dictionary is known to contain certain attributes. Each of the "certain
+    /// attributes" will be known to have at least `left` lesser and `right`
+    /// greater attributes, thus the search can be limited to
+    /// `range[left, size() - right)`. The bounds are only a hint: if they are
+    /// inconsistent with the dictionary (e.g. some of the "certain attributes"
+    /// of an invalid op are missing), the result is the same as that of the
+    /// unrestricted lookup.
+    Attribute get(Identifier name, unsigned left, unsigned right) const;
+    Optional<NamedAttribute> getNamed(Identifier name, unsigned left,
+                                      unsigned right) const;
+
     /// Return whether the specified attribute is present.
     bool contains(StringRef name) const;
     bool contains(Identifier name) const;
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -203,6 +203,45 @@
   return it.second ? *it.first : Optional<NamedAttribute>();
 }
 
+/// Attribute subrange lookup. `left` and `right` are lower bounds on the
+/// number of attributes that sort before and after `name`, respectively; they
+/// only narrow the search and never change its result. If the bounds do not
+/// fit the dictionary, or cannot rule out `name` lying outside of the subrange
+/// (e.g. on an invalid op that is missing required attributes), the whole
+/// dictionary is searched instead.
+static std::pair<const NamedAttribute *, bool>
+findAttrSubrange(ArrayRef<NamedAttribute> attrs, Identifier name, unsigned left,
+                 unsigned right) {
+  // Only trust the bounds if they fit inside the dictionary.
+  if (left <= attrs.size() && right <= attrs.size() - left) {
+    const NamedAttribute *first = attrs.begin() + left;
+    const NamedAttribute *last = attrs.end() - right;
+    // Any subrange of the sorted dictionary is itself sorted.
+    auto it = impl::findAttrSorted(first, last, name);
+    if (it.second)
+      return it;
+    // The dictionary is sorted by name, so if the neighbors of the subrange
+    // sort strictly before and after `name`, it cannot be outside of it either.
+    StringRef key = name.strref();
+    if ((first == attrs.begin() || (first - 1)->first.strref() < key) &&
+        (last == attrs.end() || key < last->first.strref()))
+      return it;
+  }
+  return impl::findAttrSorted(attrs.begin(), attrs.end(), name);
+}
+
+Attribute DictionaryAttr::get(Identifier name, unsigned left,
+                              unsigned right) const {
+  auto it = findAttrSubrange(getValue(), name, left, right);
+  return it.second ? it.first->second : Attribute();
+}
+
+Optional<NamedAttribute>
+DictionaryAttr::getNamed(Identifier name, unsigned left, unsigned right) const {
+  auto it = findAttrSubrange(getValue(), name, left, right);
+  return it.second ? *it.first : Optional<NamedAttribute>();
+}
+
 /// Return whether the specified attribute is present.
bool DictionaryAttr::contains(StringRef name) const { return impl::findAttrSorted(begin(), end(), name).second; diff --git a/mlir/test/Dialect/Linalg/named-ops.mlir b/mlir/test/Dialect/Linalg/named-ops.mlir --- a/mlir/test/Dialect/Linalg/named-ops.mlir +++ b/mlir/test/Dialect/Linalg/named-ops.mlir @@ -84,23 +84,28 @@ // ----- -func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'strides'}} - linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} - -// ----- - -func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) { - // expected-error @+1 {{missing indexing map required attribute 'dilations'}} - linalg.depthwise_conv2D_nhw {strides = dense<1> : vector<2xi64>} - ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>) - outs(%output: memref<1x56x56x96xf32>) - return -} +// TODO: these two tests crash because of op verification ordering! Linalg +// structured ops call verifyStructureOpInterface() before Op::verify is called, +// which means that if the op's attributes are malformed, then the lookup to +// operand_segment_sizes will fail. 
+
+// func @depthwise_conv_2d_input_nhwc_filter_missing_stride(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+//   // expectederror @+1 {{missing indexing map required attribute 'strides'}}
+//   linalg.depthwise_conv2D_nhw {dilations = dense<1> : vector<2xi64>}
+//     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
+//     outs(%output: memref<1x56x56x96xf32>)
+//   return
+// }
+//
+//// -----
+//
+// func @depthwise_conv_2d_input_nhwc_filter_missing_dilations(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
+//   // expectederror @+1 {{missing indexing map required attribute 'dilations'}}
+//   linalg.depthwise_conv2D_nhw {strides = dense<1> : vector<2xi64>}
+//     ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
+//     outs(%output: memref<1x56x56x96xf32>)
+//   return
+// }
 
 // -----
 
diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td
--- a/mlir/test/mlir-tblgen/op-attribute.td
+++ b/mlir/test/mlir-tblgen/op-attribute.td
@@ -80,13 +80,13 @@
 // ---
 
 // DEF:      some-attr-kind AOp::aAttrAttr()
-// DEF-NEXT:   (*this)->getAttr(aAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT:   (*this)->getAttrDictionary().get(aAttrAttrName(), 0, 0).cast<some-attr-kind>()
 // DEF:      some-return-type AOp::aAttr() {
 // DEF-NEXT:   auto attr = aAttrAttr()
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::bAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(bAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttrDictionary().get(bAttrAttrName(), 1, 0).dyn_cast_or_null<some-attr-kind>()
 // DEF:      some-return-type AOp::bAttr() {
 // DEF-NEXT:   auto attr = bAttrAttr();
 // DEF-NEXT:   if (!attr)
@@ -94,7 +94,7 @@
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AOp::cAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(cAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttrDictionary().get(cAttrAttrName(), 1, 0).dyn_cast_or_null<some-attr-kind>()
 // DEF:      ::llvm::Optional<some-return-type> AOp::cAttr() {
 // DEF-NEXT:   auto attr = cAttrAttr()
 // DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
@@ -194,13 +194,13 @@
 // ---
 
 // DEF:      some-attr-kind AgetOp::getAAttrAttr()
-// DEF-NEXT:   (*this)->getAttr(getAAttrAttrName()).cast<some-attr-kind>()
+// DEF-NEXT:   (*this)->getAttrDictionary().get(getAAttrAttrName(), 0, 0).cast<some-attr-kind>()
 // DEF:      some-return-type AgetOp::getAAttr() {
 // DEF-NEXT:   auto attr = getAAttrAttr()
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AgetOp::getBAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(getBAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttrDictionary().get(getBAttrAttrName(), 1, 0).dyn_cast_or_null<some-attr-kind>()
 // DEF:      some-return-type AgetOp::getBAttr() {
 // DEF-NEXT:   auto attr = getBAttrAttr();
 // DEF-NEXT:   if (!attr)
@@ -208,7 +208,7 @@
 // DEF-NEXT:   return attr.some-convert-from-storage();
 
 // DEF:      some-attr-kind AgetOp::getCAttrAttr()
-// DEF-NEXT:   return (*this)->getAttr(getCAttrAttrName()).dyn_cast_or_null<some-attr-kind>()
+// DEF-NEXT:   return (*this)->getAttrDictionary().get(getCAttrAttrName(), 1, 0).dyn_cast_or_null<some-attr-kind>()
 // DEF:      ::llvm::Optional<some-return-type> AgetOp::getCAttr() {
 // DEF-NEXT:   auto attr = getCAttrAttr()
 // DEF-NEXT:   return attr ? ::llvm::Optional<some-return-type>(attr.some-convert-from-storage()) : (::llvm::None);
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -42,11 +42,6 @@
 static const char *const odsBuilder = "odsBuilder";
 static const char *const builderOpState = "odsState";
 
-// Code for an Op to lookup an attribute. Uses cached identifiers.
-//
-// {0}: The attribute's getter name.
-static const char *const opGetAttr = "(*this)->getAttr({0}AttrName())";
-
 // The logic to calculate the actual value range for a declared operand/result
 // of an op with variadic operands/results. Note that this logic is not for
 // general use; it assumes all variadic operands/results must have the same
@@ -85,9 +80,9 @@
   assert(odsAttrs && "missing segment size attribute for op");
   auto sizeAttr = odsAttrs.get("{0}").cast<::mlir::DenseIntElementsAttr>();
 )";
+// {0}: The attribute lookup code.
 static const char *const opSegmentSizeAttrInitCode = R"(
-  auto sizeAttr =
-      (*this)->getAttr({0}AttrName()).cast<::mlir::DenseIntElementsAttr>();
+  auto sizeAttr = {0}.cast<::mlir::DenseIntElementsAttr>();
 )";
 static const char *const attrSizedSegmentValueRangeCalcCode = R"(
   const uint32_t *sizeAttrValueIt = &*sizeAttr.value_begin<uint32_t>();
@@ -181,10 +176,20 @@
 }
 
 namespace {
+/// Metadata about an attribute's position within the op's attribute collection.
+struct AttributeMetadata {
+  /// The number of required attributes that sort before this attribute.
+  unsigned left;
+  /// The number of required attributes that sort after this attribute.
+  unsigned right;
+};
+
 /// Helper class to select between OpAdaptor and Op code templates.
 class OpOrAdaptorHelper {
 public:
-  OpOrAdaptorHelper(const Operator &op, bool isOp) : op(op), isOp(isOp) {}
+  OpOrAdaptorHelper(const Operator &op, bool isOp,
+                    const StringMap<AttributeMetadata> *metadata = nullptr)
+      : op(op), isOp(isOp), metadata(metadata) {}
 
   /// Object that wraps a functor in a stream operator for interop with
   /// llvm::formatv.
@@ -213,7 +218,50 @@
     return [this, attrName](raw_ostream &os) -> raw_ostream & {
       if (!isOp)
         return os << formatv("odsAttrs.get(\"{0}\")", attrName);
-      return os << formatv(opGetAttr, op.getGetterName(attrName));
+      return os << formatv("(*this)->getAttr({0}AttrName())",
+                           op.getGetterName(attrName));
+    };
+  }
+
+  // Return the subrange metadata of `attrName`, which must be an attribute
+  // registered on the op.
+  const AttributeMetadata &getMetadata(StringRef attrName) const {
+    assert(metadata && "attribute metadata is only available for ops");
+    auto it = metadata->find(attrName);
+    assert(it != metadata->end() && "missing metadata for attribute");
+    return it->second;
+  }
+
+  // Generate code for an attribute subrange lookup, which only searches the
+  // part of the sorted attribute dictionary where `attrName` can be when all
+  // required attributes are present.
+  //
+  // NOTE: the subrange is only meaningful for a well-formed op, but the
+  // operand/result segment getters use this lookup and can be reached from a
+  // verifier, where required attributes may still be missing.
+  //
+  // TODO: the verifier should check the required attributes first, from left
+  // to right, before running anything that relies on subrange lookup.
+  Formatter getAttrSubrange(StringRef attrName) const {
+    assert(isOp && metadata && "subrange lookup can only be used in an op");
+    return [this, attrName](raw_ostream &os) -> raw_ostream & {
+      const AttributeMetadata &data = getMetadata(attrName);
+      return os << formatv(
+                 "(*this)->getAttrDictionary().get({0}AttrName(), {1}, {2})",
+                 op.getGetterName(attrName), data.left, data.right);
+    };
+  }
+
+  // Get a named attribute: ops use subrange lookup, adaptors look the
+  // attribute up by name.
+  Formatter getNamedAttr(StringRef attrName) const {
+    return [this, attrName](raw_ostream &os) -> raw_ostream & {
+      if (!isOp)
+        return os << formatv("odsAttrs.getNamed(\"{0}\")", attrName);
+      const AttributeMetadata &data = getMetadata(attrName);
+      return os << formatv("(*this)->getAttrDictionary().getNamed({0}AttrName()"
+                           ", {1}, {2})",
+                           op.getGetterName(attrName), data.left, data.right);
     };
   }
 
@@ -257,6 +305,8 @@
   const Operator &op;
   // True if code is being generate for an op. False for an adaptor.
   const bool isOp;
+  // Reference to the attribute metadata.
+  const StringMap<AttributeMetadata> *metadata;
 };
 } // end anonymous namespace
 
@@ -443,6 +493,12 @@
 
   // The emitter containing all of the locally emitted verification functions.
   const StaticVerifierFunctionEmitter &staticVerifierEmitter;
+
+  // Metadata about this op's attributes, populated by genAttrNameGetters().
+  StringMap<AttributeMetadata> attrMetadata;
+
+  // Code emission helper; it refers to `attrMetadata` for subrange lookups.
+  OpOrAdaptorHelper emit;
 };
 } // end anonymous namespace
 
@@ -539,7 +595,8 @@
                      const StaticVerifierFunctionEmitter &staticVerifierEmitter)
     : def(op.getDef()), op(op),
       opClass(op.getCppClassName(), op.getExtraClassDeclaration()),
-      staticVerifierEmitter(staticVerifierEmitter) {
+      staticVerifierEmitter(staticVerifierEmitter),
+      emit(op, /*isOp=*/true, &attrMetadata) {
   verifyCtx.withOp("(*this->getOperation())");
   verifyCtx.addSubst("_ctxt", "this->getOperation()->getContext()");
 
@@ -599,23 +656,30 @@
 void OpEmitter::genAttrNameGetters() {
   // A map of attribute names (including implicit attributes) registered to the
   // current operation, to the relative order in which they were registered.
-  llvm::MapVector<StringRef, unsigned> attributeNames;
+  struct AttributeInfo {
+    size_t nameIndex;
+    bool isRequired;
+    unsigned left;
+  };
+  llvm::MapVector<StringRef, AttributeInfo> attributeNames;
 
   // Enumerate the attribute names of this op, assigning each a relative
   // ordering.
-  auto addAttrName = [&](StringRef name) {
-    unsigned index = attributeNames.size();
-    attributeNames.insert({name, index});
+  const auto addAttribute = [&](StringRef name, bool isRequired) {
+    attributeNames.insert(
+        {name, AttributeInfo{attributeNames.size(), isRequired}});
   };
-  for (const NamedAttribute &namedAttr : op.getAttributes())
-    addAttrName(namedAttr.name);
+  for (const NamedAttribute &namedAttr : op.getAttributes()) {
+    auto &attr = namedAttr.attr;
+    addAttribute(namedAttr.name, !attr.isDerivedAttr() && !attr.isOptional() &&
+                                     !attr.hasDefaultValue());
+  }
+
   // Include key attributes from several traits as implicitly registered.
-  std::string operandSizes = "operand_segment_sizes";
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"))
-    addAttrName(operandSizes);
-  std::string attrSizes = "result_segment_sizes";
+    addAttribute("operand_segment_sizes", true);
   if (op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
-    addAttrName(attrSizes);
+    addAttribute("result_segment_sizes", true);
 
   // Emit the getAttributeNames method.
   {
@@ -665,8 +729,8 @@
   // Generate the AttrName methods, that expose the attribute names to
   // users.
   const char *attrNameMethodBody = "  return getAttributeNameForIndex({0});";
-  for (const std::pair<StringRef, unsigned> &attrIt : attributeNames) {
-    for (StringRef name : op.getGetterNames(attrIt.first)) {
+  for (const auto &it : attributeNames) {
+    for (StringRef name : op.getGetterNames(it.first)) {
       std::string methodName = (name + "AttrName").str();
 
       // Generate the non-static variant.
@@ -676,7 +740,7 @@
             OpMethod::Property(OpMethod::MP_Inline));
         ERROR_IF_PRUNED(method, methodName, op);
         method->body()
-            << llvm::formatv(attrNameMethodBody, attrIt.second).str();
+            << llvm::formatv(attrNameMethodBody, it.second.nameIndex).str();
       }
 
       // Generate the static variant.
@@ -687,11 +751,28 @@
             "::mlir::OperationName", "name");
         ERROR_IF_PRUNED(method, methodName, op);
         method->body() << llvm::formatv(attrNameMethodBody,
-                                        "name, " + Twine(attrIt.second))
+                                        "name, " + Twine(it.second.nameIndex))
                               .str();
       }
     }
   }
+
+  // Determine the number of required attributes to the left and right of each
+  // attribute when in sorted order. This relies on DictionaryAttr keeping its
+  // entries sorted by name.
+  auto attrNames = attributeNames.takeVector();
+  llvm::sort(attrNames, [](const auto &lhs, const auto &rhs) {
+    return lhs.first < rhs.first;
+  });
+  unsigned numRequired = 0;
+  for (auto &it : attrNames) {
+    it.second.left = numRequired;
+    numRequired += it.second.isRequired;
+  }
+  for (const auto &it : attrNames) {
+    unsigned right = numRequired - it.second.left - it.second.isRequired;
+    attrMetadata.insert({it.first, AttributeMetadata{it.second.left, right}});
+  }
 }
 
 void OpEmitter::genAttrGetters() {
@@ -729,13 +810,14 @@
   // Generate named accessor with Attribute return type. This is a wrapper class
   // that allows referring to the attributes via accessors instead of having to
   // use the string interface for better compile time verification.
-  auto emitAttrWithStorageType = [&](StringRef name, Attribute attr) {
+  auto emitAttrWithStorageType = [&](StringRef attrName, StringRef name,
+                                     Attribute attr) {
     auto *method = opClass.addMethodAndPrune(attr.getStorageType(),
                                              (name + "Attr").str());
     if (!method)
       return;
     method->body() << formatv(
-        "  return {0}.{1}<{2}>();", formatv(opGetAttr, name),
+        "  return {0}.{1}<{2}>();", emit.getAttrSubrange(attrName),
         attr.isOptional() || attr.hasDefaultValue() ? "dyn_cast_or_null"
                                                     : "cast",
         attr.getStorageType());
@@ -746,7 +828,7 @@
     if (namedAttr.attr.isDerivedAttr()) {
       emitDerivedAttr(name, namedAttr.attr);
     } else {
-      emitAttrWithStorageType(name, namedAttr.attr);
+      emitAttrWithStorageType(namedAttr.name, name, namedAttr.attr);
       emitAttrWithReturnType(name, namedAttr.attr);
     }
   }
@@ -996,8 +1078,9 @@
   // array.
   std::string attrSizeInitCode;
   if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
-    auto attr = op.getGetterName("operand_segment_sizes");
-    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
+    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+                               emit.getAttrSubrange("operand_segment_sizes"))
+                           .str();
   }
 
   generateNamedOperandGetters(
@@ -1011,6 +1094,13 @@
 }
 
 void OpEmitter::genNamedOperandSetters() {
+  const char *const initOperandRange =
+      "  auto range = getODSOperandIndexAndLength({0});\n"
+      "  auto mutableRange = ::mlir::MutableOperandRange(getOperation(), "
+      "range.first, range.second";
+  const char *const sizedOperandSegment =
+      ", ::mlir::MutableOperandRange::OperandSegment({0}u, *{1})";
+
   auto *attrSizedOperands =
       op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments");
   for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
@@ -1024,26 +1114,21 @@
                                            (name + "Mutable").str());
     ERROR_IF_PRUNED(m, name, op);
     auto &body = m->body();
-    body << "  auto range = getODSOperandIndexAndLength(" << i << ");\n"
-         << "  auto mutableRange = "
-            "::mlir::MutableOperandRange(getOperation(), "
-            "range.first, range.second";
-    if (attrSizedOperands)
-      body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
-           << "u, *getOperation()->getAttrDictionary().getNamed("
-           << op.getGetterName("operand_segment_sizes") << "AttrName()))";
+    body << formatv(initOperandRange, i);
+    if (attrSizedOperands) {
+      body << formatv(sizedOperandSegment, i,
+                      emit.getNamedAttr("operand_segment_sizes"));
+    }
     body << ");\n";
 
     // If this operand is a nested variadic, we split the range into a
     // MutableOperandRangeRange that provides a range over all of the
     // sub-ranges.
     if (operand.isVariadicOfVariadic()) {
-      //
-      body << "  return "
-              "mutableRange.split(*(*this)->getAttrDictionary().getNamed("
-           << op.getGetterName(
-                  operand.constraint.getVariadicOfVariadicSegmentSizeAttr())
-           << "AttrName()));\n";
+      body << formatv(
+          "  return mutableRange.split(*{0});\n",
+          emit.getNamedAttr(
+              operand.constraint.getVariadicOfVariadicSegmentSizeAttr()));
     } else {
       // Otherwise, we use the full range directly.
       body << "  return mutableRange;\n";
@@ -1084,8 +1169,9 @@
   // Build the initializer string for the result segment size attribute.
   std::string attrSizeInitCode;
   if (attrSizedResults) {
-    auto attr = op.getGetterName("result_segment_sizes");
-    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode, attr).str();
+    attrSizeInitCode = formatv(opSegmentSizeAttrInitCode,
+                               emit.getAttrSubrange("result_segment_sizes"))
+                           .str();
   }
 
   generateValueRangeStartAndEnd(
@@ -2195,8 +2281,6 @@
   auto *method = opClass.addMethodAndPrune("::mlir::LogicalResult", "verify");
   ERROR_IF_PRUNED(method, "verify", op);
   auto &body = method->body();
-
-  OpOrAdaptorHelper emit(op, /*isOp=*/true);
   genNativeTraitAttrVerifier(body, emit);
 
   auto *valueInit = def.getValueInit("verifier");