diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td rename from mlir/test/mlir-tblgen/op-decl.td rename to mlir/test/mlir-tblgen/op-decl-and-defs.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -1,7 +1,9 @@ -// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s +// RUN: mlir-tblgen -gen-op-decls -I %S/../../include %s | FileCheck %s // RUN: mlir-tblgen -gen-op-decls -op-include-regex="test.a_op" -I %S/../../include %s | FileCheck %s --check-prefix=REDUCE_INC // RUN: mlir-tblgen -gen-op-decls -op-exclude-regex="test.a_op" -I %S/../../include %s | FileCheck %s --check-prefix=REDUCE_EXC +// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s --check-prefix=DEFS + include "mlir/IR/OpBase.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -58,6 +60,8 @@ // CHECK: ::mlir::ValueRange b(); // CHECK: ::mlir::IntegerAttr attr1(); // CHECK: ::mlir::FloatAttr attr2(); +// CHECK: ::mlir::Region &someRegion(); +// CHECK: ::mlir::RegionRange someRegions(); // CHECK: private: // CHECK: ::mlir::ValueRange odsOperands; // CHECK: }; @@ -96,6 +100,13 @@ // CHECK: void displayGraph(); // CHECK: }; +// DEFS-LABEL: NS::AOp definitions + +// DEFS: AOpAdaptor::AOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs, ::mlir::RegionRange regions) : odsOperands(values), odsAttrs(attrs), odsRegions(regions) +// DEFS: ::mlir::RegionRange AOpAdaptor::getRegions() +// DEFS: ::mlir::RegionRange AOpAdaptor::someRegions() +// DEFS-NEXT: return odsRegions.drop_front(1); + // Check AttrSizedOperandSegments // --- diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2246,6 +2246,7 @@ : op(op), adaptor(op.getAdaptorName()) { adaptor.newField("::mlir::ValueRange", 
"odsOperands"); adaptor.newField("::mlir::DictionaryAttr", "odsAttrs"); + adaptor.newField("::mlir::RegionRange", "odsRegions"); const auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); { @@ -2253,10 +2254,12 @@ paramList.emplace_back("::mlir::ValueRange", "values"); paramList.emplace_back("::mlir::DictionaryAttr", "attrs", attrSizedOperands ? "" : "nullptr"); + paramList.emplace_back("::mlir::RegionRange", "regions", "{}"); auto *constructor = adaptor.addConstructorAndPrune(std::move(paramList)); constructor->addMemberInitializer("odsOperands", "values"); constructor->addMemberInitializer("odsAttrs", "attrs"); + constructor->addMemberInitializer("odsRegions", "regions"); } { @@ -2264,8 +2267,13 @@ llvm::formatv("{0}&", op.getCppClassName()).str(), "op"); constructor->addMemberInitializer("odsOperands", "op->getOperands()"); constructor->addMemberInitializer("odsAttrs", "op->getAttrDictionary()"); + constructor->addMemberInitializer("odsRegions", "op->getRegions()"); } + { + auto *m = adaptor.addMethodAndPrune("::mlir::ValueRange", "getOperands"); + m->body() << " return odsOperands;"; + } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); generateNamedOperandGetters(op, adaptor, sizeAttrInit, @@ -2299,6 +2307,11 @@ body << " return attr;\n"; }; + { + auto *m = + adaptor.addMethodAndPrune("::mlir::DictionaryAttr", "getAttributes"); + m->body() << " return odsAttrs;"; + } for (auto &namedAttr : op.getAttributes()) { const auto &name = namedAttr.name; const auto &attr = namedAttr.attr; @@ -2306,6 +2319,27 @@ emitAttr(name, attr); } + unsigned numRegions = op.getNumRegions(); + if (numRegions > 0) { + auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", "getRegions"); + m->body() << " return odsRegions;"; + } + for (unsigned i = 0; i < numRegions; ++i) { + const auto &region = op.getRegion(i); + if (region.name.empty()) + continue; + + // Generate the accessors for a variadic region.
+ if (region.isVariadic()) { + auto *m = adaptor.addMethodAndPrune("::mlir::RegionRange", region.name); + m->body() << formatv(" return odsRegions.drop_front({0});", i); + continue; + } + + auto *m = adaptor.addMethodAndPrune("::mlir::Region &", region.name); + m->body() << formatv(" return *odsRegions[{0}];", i); + } + // Add verification function. addVerification(); }