diff --git a/mlir/include/mlir/TableGen/OpClass.h b/mlir/include/mlir/TableGen/OpClass.h --- a/mlir/include/mlir/TableGen/OpClass.h +++ b/mlir/include/mlir/TableGen/OpClass.h @@ -145,10 +145,6 @@ public: explicit OpClass(StringRef name, StringRef extraClassDeclaration = ""); - // Sets whether this OpClass should generate the using directive for its - // associate operand adaptor class. - void setHasOperandAdaptorClass(bool has); - // Adds an op trait. void addTrait(Twine trait); @@ -160,7 +156,6 @@ StringRef extraClassDeclaration; SmallVector traitsVec; StringSet<> traitsSet; - bool hasOperandAdaptor; }; } // namespace tblgen diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -188,12 +188,7 @@ //===----------------------------------------------------------------------===// tblgen::OpClass::OpClass(StringRef name, StringRef extraClassDeclaration) - : Class(name), extraClassDeclaration(extraClassDeclaration), - hasOperandAdaptor(true) {} - -void tblgen::OpClass::setHasOperandAdaptorClass(bool has) { - hasOperandAdaptor = has; -} + : Class(name), extraClassDeclaration(extraClassDeclaration) {} void tblgen::OpClass::addTrait(Twine trait) { auto traitStr = trait.str(); @@ -207,8 +202,7 @@ os << ", " << trait; os << "> {\npublic:\n"; os << " using Op::Op;\n"; - if (hasOperandAdaptor) - os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; + os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; bool hasPrivateMethod = false; for (const auto &method : methods) { diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -50,12 +50,14 @@ // CHECK: class AOpOperandAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ArrayRef values); +// CHECK: AOpOperandAdaptor(ArrayRef values // CHECK: ArrayRef getODSOperands(unsigned index); // CHECK: 
Value a(); // CHECK: ArrayRef b(); +// CHECK: IntegerAttr attr1(); +// CHECK: FloatAttr attr2(); // CHECK: private: -// CHECK: ArrayRef tblgen_operands; +// CHECK: ArrayRef odsOperands; // CHECK: }; // CHECK: class AOp : public Op::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove @@ -90,6 +92,29 @@ // CHECK: void displayGraph(); // CHECK: }; +// Check AttrSizedOperandSegments +// --- + +def NS_AttrSizedOperandOp : NS_Op<"attr_sized_operands", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$a, + Variadic:$b, + I32:$c, + Variadic:$d, + I32ElementsAttr:$operand_segment_sizes + ); +} + +// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor( +// CHECK-SAME: ArrayRef values +// CHECK-SAME: DictionaryAttr attrs +// CHECK: ArrayRef a(); +// CHECK: ArrayRef b(); +// CHECK: Value c(); +// CHECK: ArrayRef d(); +// CHECK: DenseIntElementsAttr operand_segment_sizes(); + // Check op trait for different number of operands // --- @@ -150,3 +175,4 @@ // CHECK-LABEL: _BOp declarations // CHECK: class _BOp : public Op<_BOp + diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -15,7 +15,7 @@ // CHECK-LABEL: OpA definitions // CHECK: OpAOperandAdaptor::OpAOperandAdaptor -// CHECK-NEXT: tblgen_operands = values +// CHECK-NEXT: odsOperands = values // CHECK: void OpA::build // CHECK: Value input diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -70,13 +70,19 @@ // (variadic or not). // // {0}: The name of the attribute specifying the segment sizes. 
-const char *attrSizedSegmentValueRangeCalcCode = R"( +const char *adapterSegmentSizeAttrInitCode = R"( + assert(odsAttrs && "missing segment size attribute for op"); + auto sizeAttr = odsAttrs.get("{0}").cast(); +)"; +const char *opSegmentSizeAttrInitCode = R"( auto sizeAttr = getAttrOfType("{0}"); +)"; +const char *attrSizedSegmentValueRangeCalcCode = R"( unsigned start = 0; for (unsigned i = 0; i < index; ++i) start += (*(sizeAttr.begin() + i)).getZExtValue(); unsigned size = (*(sizeAttr.begin() + index)).getZExtValue(); - return {{start, size}; + return {start, size}; )"; // The logic to build a range of either operand or result values. @@ -496,15 +502,14 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, int numVariadic, int numNonVariadic, StringRef rangeSizeCall, bool hasAttrSegmentSize, - StringRef segmentSizeAttr, RangeT &&odsValues) { + StringRef sizeAttrInit, RangeT &&odsValues) { auto &method = opClass.newMethod("std::pair", methodName, "unsigned index"); if (numVariadic == 0) { method.body() << " return {index, 1};\n"; } else if (hasAttrSegmentSize) { - method.body() << formatv(attrSizedSegmentValueRangeCalcCode, - segmentSizeAttr); + method.body() << sizeAttrInit << attrSizedSegmentValueRangeCalcCode; } else { // Because the op can have arbitrarily interleaved variadic and non-variadic // operands, we need to embed a list in the "sink" getter method for @@ -532,6 +537,7 @@ // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. static void generateNamedOperandGetters(const Operator &op, Class &opClass, + StringRef sizeAttrInit, StringRef rangeType, StringRef rangeBeginCall, StringRef rangeSizeCall, @@ -563,10 +569,10 @@ // First emit a few "sink" getter methods upon which we layer all nicer named // getter methods. 
- generateValueRangeStartAndEnd( - opClass, "getODSOperandIndexAndLength", numVariadicOperands, - numNormalOperands, rangeSizeCall, attrSizedOperands, - "operand_segment_sizes", const_cast(op).getOperands()); + generateValueRangeStartAndEnd(opClass, "getODSOperandIndexAndLength", + numVariadicOperands, numNormalOperands, + rangeSizeCall, attrSizedOperands, sizeAttrInit, + const_cast(op).getOperands()); auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); m.body() << formatv(valueRangeReturnCode, rangeBeginCall, @@ -574,7 +580,6 @@ // Then we emit nicer named getter methods by redirecting to the "sink" getter // method. - for (int i = 0; i != numOperands; ++i) { const auto &operand = op.getOperand(i); if (operand.name.empty()) @@ -595,11 +600,11 @@ } void OpEmitter::genNamedOperandGetters() { - if (op.getTrait("OpTrait::AttrSizedOperandSegments")) - opClass.setHasOperandAdaptorClass(false); - generateNamedOperandGetters( - op, opClass, /*rangeType=*/"Operation::operand_range", + op, opClass, + /*sizeAttrInit=*/ + formatv(opSegmentSizeAttrInitCode, "operand_segment_sizes").str(), + /*rangeType=*/"Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", /*rangeSizeCall=*/"getOperation()->getNumOperands()", /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); @@ -656,7 +661,8 @@ generateValueRangeStartAndEnd( opClass, "getODSResultIndexAndLength", numVariadicResults, numNormalResults, "getOperation()->getNumResults()", attrSizedResults, - "result_segment_sizes", op.getResults()); + formatv(opSegmentSizeAttrInitCode, "result_segment_sizes").str(), + op.getResults()); auto &m = opClass.newMethod("Operation::result_range", "getODSResults", "unsigned index"); m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", @@ -1840,15 +1846,56 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op) : adapterClass(op.getCppClassName().str() + "OperandAdaptor") { - 
adapterClass.newField("ArrayRef", "tblgen_operands"); - auto &constructor = adapterClass.newConstructor("ArrayRef values"); - constructor.body() << " tblgen_operands = values;\n"; - - generateNamedOperandGetters(op, adapterClass, + adapterClass.newField("ArrayRef", "odsOperands"); + adapterClass.newField("DictionaryAttr", "odsAttrs"); + const auto *attrSizedOperands = + op.getTrait("OpTrait::AttrSizedOperandSegments"); + auto &constructor = adapterClass.newConstructor( + attrSizedOperands + ? "ArrayRef values, DictionaryAttr attrs" + : "ArrayRef values, DictionaryAttr attrs = nullptr"); + constructor.body() << " odsOperands = values;\n"; + constructor.body() << " odsAttrs = attrs;\n"; + + std::string sizeAttrInit = + formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); + generateNamedOperandGetters(op, adapterClass, sizeAttrInit, /*rangeType=*/"ArrayRef", - /*rangeBeginCall=*/"tblgen_operands.begin()", - /*rangeSizeCall=*/"tblgen_operands.size()", - /*getOperandCallPattern=*/"tblgen_operands[{0}]"); + /*rangeBeginCall=*/"odsOperands.begin()", + /*rangeSizeCall=*/"odsOperands.size()", + /*getOperandCallPattern=*/"odsOperands[{0}]"); + + FmtContext fctx; + fctx.withBuilder("mlir::Builder(odsAttrs.getContext())"); + + auto emitAttr = [&](StringRef name, Attribute attr) { + auto &body = adapterClass.newMethod(attr.getStorageType(), name).body(); + body << " assert(odsAttrs && \"no attributes when constructing adapter\");" + << "\n " << attr.getStorageType() << " attr = " + << "odsAttrs.get(\"" << name << "\")."; + if (attr.hasDefaultValue() || attr.isOptional()) + body << "dyn_cast_or_null<"; + else + body << "cast<"; + body << attr.getStorageType() << ">();\n"; + + if (attr.hasDefaultValue()) { + // Use the default value if attribute is not set. + // TODO: this is inefficient, we are recreating the attribute for every + // call. This should be set instead. 
+ std::string defaultValue = std::string( + tgfmt(attr.getConstBuilderTemplate(), &fctx, attr.getDefaultValue())); + body << " if (!attr)\n attr = " << defaultValue << ";\n"; + } + body << " return attr;\n"; + }; + + for (auto &namedAttr : op.getAttributes()) { + const auto &name = namedAttr.name; + const auto &attr = namedAttr.attr; + if (!attr.isDerivedAttr()) + emitAttr(name, attr); + } } void OpOperandAdaptorEmitter::emitDecl(const Operator &op, raw_ostream &os) { @@ -1873,19 +1920,13 @@ } for (auto *def : defs) { Operator op(*def); - const auto *attrSizedOperands = - op.getTrait("OpTrait::AttrSizedOperandSegments"); if (emitDecl) { os << formatv(opCommentHeader, op.getQualCppClassName(), "declarations"); - // We cannot generate the operand adaptor class if operand getters depend - // on an attribute. - if (!attrSizedOperands) - OpOperandAdaptorEmitter::emitDecl(op, os); + OpOperandAdaptorEmitter::emitDecl(op, os); OpEmitter::emitDecl(op, os); } else { os << formatv(opCommentHeader, op.getQualCppClassName(), "definitions"); - if (!attrSizedOperands) - OpOperandAdaptorEmitter::emitDef(op, os); + OpOperandAdaptorEmitter::emitDef(op, os); OpEmitter::emitDef(op, os); } }