diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -265,6 +265,24 @@ TODO: Design and implement more primitive constraints +### Operation regions + +The regions of an operation are specified inside of the `dag`-typed `regions`, +led by `region`: + +```tablegen +let regions = (region + <region-constraint>:$<region-name>, + ... +); +``` + +#### Variadic regions + +Similar to the `Variadic` class used for variadic operands and results, +`VariadicRegion<...>` can be used for regions. Variadic regions can currently +only be specified as the last region in the regions list. + ### Operation results Similar to operands, results are specified inside the `dag`-typed `results`, led diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1533,6 +1533,10 @@ CPred<"$_self.getBlocks().size() == " # numBlocks>, "region with " # numBlocks # " blocks">; +// A variadic region constraint. It expands to zero or more of the base region. 
+class VariadicRegion + : Region; + //===----------------------------------------------------------------------===// // Successor definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -368,6 +368,10 @@ LogicalResult verifyOperandsAreFloatLike(Operation *op); LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); LogicalResult verifySameTypeOperands(Operation *op); +LogicalResult verifyZeroRegion(Operation *op); +LogicalResult verifyOneRegion(Operation *op); +LogicalResult verifyNRegions(Operation *op, unsigned numRegions); +LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions); LogicalResult verifyZeroResult(Operation *op); LogicalResult verifyOneResult(Operation *op); LogicalResult verifyNResults(Operation *op, unsigned numOperands); @@ -530,6 +534,89 @@ : public detail::MultiOperandTraitBase {}; //===----------------------------------------------------------------------===// +// Region Traits + +/// This class provides verification for ops that are known to have zero +/// regions. +template +class ZeroRegion : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyZeroRegion(op); + } +}; + +namespace detail { +/// Utility trait base that provides accessors for derived traits that have +/// multiple regions. +template class TraitType> +struct MultiRegionTraitBase : public TraitBase { + using region_iterator = MutableArrayRef; + using region_range = RegionRange; + + /// Return the number of regions. + unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } + + /// Return the region at `index`. + Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } + + /// Region iterator access. 
+ region_iterator region_begin() { + return this->getOperation()->region_begin(); + } + region_iterator region_end() { return this->getOperation()->region_end(); } + region_range getRegions() { return this->getOperation()->getRegions(); } +}; +} // end namespace detail + +/// This class provides APIs for ops that are known to have a single region. +template +class OneRegion : public TraitBase { +public: + Region &getRegion() { return this->getOperation()->getRegion(0); } + + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyOneRegion(op); + } +}; + +/// This class provides the API for ops that are known to have a specified +/// number of regions. +template class NRegions { +public: + static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2"); + + template + class Impl + : public detail::MultiRegionTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyNRegions(op, N); + } + }; +}; + +/// This class provides APIs for ops that are known to have at least a specified +/// number of regions. +template class AtLeastNRegions { +public: + template + class Impl : public detail::MultiRegionTraitBase::Impl> { + public: + static LogicalResult verifyTrait(Operation *op) { + return impl::verifyAtLeastNRegions(op, N); + } + }; +}; + +/// This class provides the API for ops which have an unknown number of +/// regions. +template +class VariadicRegions + : public detail::MultiRegionTraitBase {}; + +//===----------------------------------------------------------------------===// // Result Traits /// This class provides return value APIs for ops that are known to have diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -165,11 +165,20 @@ // requiring the raw MLIR trait here. const OpTrait *getTrait(llvm::StringRef trait) const; + // Regions. 
+ using const_region_iterator = const NamedRegion *; + const_region_iterator region_begin() const; + const_region_iterator region_end() const; + llvm::iterator_range getRegions() const; + // Returns the number of regions. unsigned getNumRegions() const; // Returns the `index`-th region. const NamedRegion &getRegion(unsigned index) const; + // Returns the number of variadic regions in this operation. + unsigned getNumVariadicRegions() const; + // Successors. using const_successor_iterator = const NamedSuccessor *; const_successor_iterator successor_begin() const; diff --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h --- a/mlir/include/mlir/TableGen/Region.h +++ b/mlir/include/mlir/TableGen/Region.h @@ -22,10 +22,16 @@ using Constraint::Constraint; static bool classof(const Constraint *c) { return c->getKind() == CK_Region; } + + // Returns true if this region is variadic. + bool isVariadic() const; }; // A struct bundling a region's constraint and its name. struct NamedRegion { + // Returns true if this region is variadic. 
+ bool isVariadic() const { return constraint.isVariadic(); } + StringRef name; Region constraint; }; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -709,6 +709,32 @@ return success(); } +LogicalResult OpTrait::impl::verifyZeroRegion(Operation *op) { + if (op->getNumRegions() != 0) + return op->emitOpError() << "requires zero regions"; + return success(); +} + +LogicalResult OpTrait::impl::verifyOneRegion(Operation *op) { + if (op->getNumRegions() != 1) + return op->emitOpError() << "requires one region"; + return success(); +} + +LogicalResult OpTrait::impl::verifyNRegions(Operation *op, + unsigned numRegions) { + if (op->getNumRegions() != numRegions) + return op->emitOpError() << "expected " << numRegions << " regions"; + return success(); +} + +LogicalResult OpTrait::impl::verifyAtLeastNRegions(Operation *op, + unsigned numRegions) { + if (op->getNumRegions() < numRegions) + return op->emitOpError() << "expected " << numRegions << " or more regions"; + return success(); +} + LogicalResult OpTrait::impl::verifyZeroResult(Operation *op) { if (op->getNumResults() != 0) return op->emitOpError() << "requires zero results"; diff --git a/mlir/lib/TableGen/CMakeLists.txt b/mlir/lib/TableGen/CMakeLists.txt --- a/mlir/lib/TableGen/CMakeLists.txt +++ b/mlir/lib/TableGen/CMakeLists.txt @@ -11,6 +11,7 @@ Pass.cpp Pattern.cpp Predicate.cpp + Region.cpp SideEffects.cpp Successor.cpp Type.cpp diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -173,12 +173,28 @@ return nullptr; } +auto tblgen::Operator::region_begin() const -> const_region_iterator { + return regions.begin(); +} +auto tblgen::Operator::region_end() const -> const_region_iterator { + return regions.end(); +} +auto tblgen::Operator::getRegions() const + -> llvm::iterator_range { + return {region_begin(), region_end()}; +} + 
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { return regions[index]; } +unsigned tblgen::Operator::getNumVariadicRegions() const { + return llvm::count_if(regions, + [](const NamedRegion &c) { return c.isVariadic(); }); +} + auto tblgen::Operator::successor_begin() const -> const_successor_iterator { return successors.begin(); } @@ -388,7 +404,16 @@ PrintFatalError(def.getLoc(), Twine("undefined kind for region #") + Twine(i)); } - regions.push_back({name, Region(regionInit->getDef())}); + Region region(regionInit->getDef()); + if (region.isVariadic()) { + // Only support variadic regions if it is the last one for now. + if (i != e - 1) + PrintFatalError(def.getLoc(), "only the last region can be variadic"); + if (name.empty()) + PrintFatalError(def.getLoc(), "variadic regions must be named"); + } + + regions.push_back({name, region}); } LLVM_DEBUG(print(llvm::dbgs())); diff --git a/mlir/lib/TableGen/Region.cpp b/mlir/lib/TableGen/Region.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/TableGen/Region.cpp @@ -0,0 +1,20 @@ +//===- Region.cpp - Region class ------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Region wrapper to simplify using TableGen Record defining a MLIR Region. +// +//===----------------------------------------------------------------------===// + +#include "mlir/TableGen/Region.h" +#include "llvm/TableGen/Record.h" + +using namespace mlir; +using namespace mlir::tblgen; + +// Returns true if this region is variadic. 
+bool Region::isVariadic() const { return def->isSubClassOf("VariadicRegion"); } diff --git a/mlir/test/Dialect/LLVMIR/func.mlir b/mlir/test/Dialect/LLVMIR/func.mlir --- a/mlir/test/Dialect/LLVMIR/func.mlir +++ b/mlir/test/Dialect/LLVMIR/func.mlir @@ -118,7 +118,7 @@ // ----- module { - // expected-error@+1 {{expects one region}} + // expected-error@+1 {{requires one region}} "llvm.func"() {sym_name = "no_region", type = !llvm<"void ()">} : () -> () } diff --git a/mlir/test/Dialect/LLVMIR/global.mlir b/mlir/test/Dialect/LLVMIR/global.mlir --- a/mlir/test/Dialect/LLVMIR/global.mlir +++ b/mlir/test/Dialect/LLVMIR/global.mlir @@ -60,12 +60,12 @@ // ----- // expected-error @+1 {{op requires string attribute 'sym_name'}} -"llvm.mlir.global"() {type = !llvm.i64, constant, value = 42 : i64} : () -> () +"llvm.mlir.global"() ({}) {type = !llvm.i64, constant, value = 42 : i64} : () -> () // ----- // expected-error @+1 {{op requires attribute 'type'}} -"llvm.mlir.global"() {sym_name = "foo", constant, value = 42 : i64} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "foo", constant, value = 42 : i64} : () -> () // ----- @@ -75,12 +75,12 @@ // ----- // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} -"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = -1 : i32, linkage = 0} : () -> () // ----- // expected-error @+1 {{'addr_space' failed to satisfy constraint: 32-bit signless integer attribute whose value is non-negative}} -"llvm.mlir.global"() {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> () +"llvm.mlir.global"() ({}) {sym_name = "foo", type = !llvm.i64, value = 42 : i64, addr_space = 1.0 : f32, linkage = 0} : () -> () // ----- diff --git a/mlir/test/Dialect/Loops/invalid.mlir 
b/mlir/test/Dialect/Loops/invalid.mlir --- a/mlir/test/Dialect/Loops/invalid.mlir +++ b/mlir/test/Dialect/Loops/invalid.mlir @@ -2,7 +2,7 @@ func @loop_for_lb(%arg0: f32, %arg1: index) { // expected-error@+1 {{operand #0 must be index}} - "loop.for"(%arg0, %arg1, %arg1) : (f32, index, index) -> () + "loop.for"(%arg0, %arg1, %arg1) ({}) : (f32, index, index) -> () return } @@ -10,7 +10,7 @@ func @loop_for_ub(%arg0: f32, %arg1: index) { // expected-error@+1 {{operand #1 must be index}} - "loop.for"(%arg1, %arg0, %arg1) : (index, f32, index) -> () + "loop.for"(%arg1, %arg0, %arg1) ({}) : (index, f32, index) -> () return } @@ -18,7 +18,7 @@ func @loop_for_step(%arg0: f32, %arg1: index) { // expected-error@+1 {{operand #2 must be index}} - "loop.for"(%arg1, %arg1, %arg0) : (index, index, f32) -> () + "loop.for"(%arg1, %arg1, %arg0) ({}) : (index, index, f32) -> () return } @@ -37,7 +37,7 @@ // ----- func @loop_for_one_region(%arg0: index) { - // expected-error@+1 {{incorrect number of regions: expected 1 but found 2}} + // expected-error@+1 {{requires one region}} "loop.for"(%arg0, %arg0, %arg0) ( {loop.yield}, {loop.yield} @@ -77,14 +77,14 @@ func @loop_if_not_i1(%arg0: index) { // expected-error@+1 {{operand #0 must be 1-bit signless integer}} - "loop.if"(%arg0) : (index) -> () + "loop.if"(%arg0) ({}, {}) : (index) -> () return } // ----- func @loop_if_more_than_2_regions(%arg0: i1) { - // expected-error@+1 {{op has incorrect number of regions: expected 2}} + // expected-error@+1 {{expected 2 regions}} "loop.if"(%arg0) ({}, {}, {}): (i1) -> () return } diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -16,7 +16,7 @@ // ----- func @missing_regions() { - // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}} + // expected-error@+1 {{expected 2 regions}} "test.two_region_op"()( {"work"() : () -> ()} ) : () -> () @@ -26,7 +26,7 @@ // ----- func @extra_regions() { - 
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}} + // expected-error@+1 {{expected 2 regions}} "test.two_region_op"()( {"work"() : () -> ()}, {"work"() : () -> ()}, diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -26,7 +26,10 @@ Variadic:$s ); - let regions = (region AnyRegion:$someRegion); + let regions = (region + AnyRegion:$someRegion, + VariadicRegion:$someRegions + ); let builders = [OpBuilder<"Value val">]; let parser = [{ foo }]; let printer = [{ bar }]; @@ -55,7 +58,7 @@ // CHECK: ArrayRef tblgen_operands; // CHECK: }; -// CHECK: class AOp : public Op::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove +// CHECK: class AOp : public Op::Impl, OpTrait::AtLeastNResults<1>::Impl, OpTrait::ZeroSuccessor, OpTrait::AtLeastNOperands<1>::Impl, OpTrait::IsIsolatedFromAbove // CHECK-NOT: OpTrait::IsIsolatedFromAbove // CHECK: public: // CHECK: using Op::Op; @@ -67,14 +70,15 @@ // CHECK: Operation::result_range getODSResults(unsigned index); // CHECK: Value r(); // CHECK: Region &someRegion(); +// CHECK: MutableArrayRef someRegions(); // CHECK: IntegerAttr attr1Attr() // CHECK: APInt attr1(); // CHECK: FloatAttr attr2Attr() // CHECK: Optional< APFloat > attr2(); // CHECK: static void build(Value val); -// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, IntegerAttr attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2) -// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes) +// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, IntegerAttr 
attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(Builder *odsBuilder, OperationState &odsState, Type r, ArrayRef s, Value a, ValueRange b, APInt attr1, /*optional*/FloatAttr attr2, unsigned someRegionsCount) +// CHECK: static void build(Builder *, OperationState &odsState, ArrayRef resultTypes, ValueRange operands, ArrayRef attributes, unsigned numRegions) // CHECK: static ParseResult parse(OpAsmParser &parser, OperationState &result); // CHECK: void print(OpAsmPrinter &p); // CHECK: LogicalResult verify(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -603,10 +603,19 @@ unsigned numRegions = op.getNumRegions(); for (unsigned i = 0; i < numRegions; ++i) { const auto &region = op.getRegion(i); - if (!region.name.empty()) { - auto &m = opClass.newMethod("Region &", region.name); - m.body() << formatv(" return this->getOperation()->getRegion({0});", i); + if (region.name.empty()) + continue; + + // Generate the accessors for a variadic region. 
+ if (region.isVariadic()) { + auto &m = opClass.newMethod("MutableArrayRef", region.name); + m.body() << formatv( + " return this->getOperation()->getRegions().drop_front({0});", i); + continue; } + + auto &m = opClass.newMethod("Region &", region.name); + m.body() << formatv(" return this->getOperation()->getRegion({0});", i); } } @@ -739,6 +748,8 @@ std::string params = std::string("Builder *odsBuilder, OperationState &") + builderOpState + ", ValueRange operands, ArrayRef attributes"; + if (op.getNumVariadicRegions()) + params += ", unsigned numRegions"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); @@ -750,8 +761,10 @@ // Create the correct number of regions if (int numRegions = op.getNumRegions()) { - for (int i = 0; i < numRegions; ++i) - m.body() << " (void)" << builderOpState << ".addRegion();\n"; + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; } // Result types @@ -897,6 +910,8 @@ builderOpState + ", ArrayRef resultTypes, ValueRange operands, " "ArrayRef attributes"; + if (op.getNumVariadicRegions()) + params += ", unsigned numRegions"; auto &m = opClass.newMethod("void", "build", params, OpMethod::MP_Static); auto &body = m.body(); @@ -913,8 +928,10 @@ // Create the correct number of regions if (int numRegions = op.getNumRegions()) { - for (int i = 0; i < numRegions; ++i) - m.body() << " (void)" << builderOpState << ".addRegion();\n"; + body << llvm::formatv( + " for (unsigned i = 0; i != {0}; ++i)\n", + (op.getNumVariadicRegions() ? "numRegions" : Twine(numRegions))); + body << " (void)" << builderOpState << ".addRegion();\n"; } // Result types @@ -1042,11 +1059,17 @@ } } - /// Insert parameters for the block and operands for each successor. + /// Insert parameters for each successor. 
for (const NamedSuccessor &succ : op.getSuccessors()) { paramList += (succ.isVariadic() ? ", ArrayRef " : ", Block *"); paramList += succ.name; } + + /// Insert parameters for variadic regions. + for (const NamedRegion &region : op.getRegions()) { + if (region.isVariadic()) + paramList += llvm::formatv(", unsigned {0}Count", region.name).str(); + } } void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, @@ -1110,9 +1133,12 @@ } // Create the correct number of regions. - if (int numRegions = op.getNumRegions()) { - for (int i = 0; i < numRegions; ++i) - body << " (void)" << builderOpState << ".addRegion();\n"; + for (const NamedRegion &region : op.getRegions()) { + if (region.isVariadic()) + body << formatv(" for (unsigned i = 0; i < {0}Count; ++i)\n ", + region.name); + + body << " (void)" << builderOpState << ".addRegion();\n"; } // Push all successors to the result. @@ -1436,33 +1462,42 @@ } void OpEmitter::genRegionVerifier(OpMethodBody &body) { + // If we have no regions, there is nothing more to do. unsigned numRegions = op.getNumRegions(); + if (numRegions == 0) + return; - // Verify this op has the correct number of regions - body << formatv( - " if (this->getOperation()->getNumRegions() != {0}) {\n " - "return emitOpError(\"has incorrect number of regions: expected {0} but " - "found \") << this->getOperation()->getNumRegions();\n }\n", - numRegions); + body << "{\n"; + body << " unsigned index = 0; (void)index;\n"; for (unsigned i = 0; i < numRegions; ++i) { const auto &region = op.getRegion(i); + if (region.constraint.getPredicate().isNull()) + continue; - std::string name = std::string(formatv("#{0}", i)); - if (!region.name.empty()) { - name += std::string(formatv(" ('{0}')", region.name)); - } - - auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str(); + body << " for (Region &region : "; + body << formatv( + region.isVariadic() + ? 
"{0}()" + : "MutableArrayRef(this->getOperation()->getRegion({1}))", + region.name, i); + body << ") {\n"; auto constraint = tgfmt(region.constraint.getConditionTemplate(), - &verifyCtx.withSelf(getRegion)) + &verifyCtx.withSelf("region")) .str(); - body << formatv(" if (!({0})) {\n " - "return emitOpError(\"region {1} failed to verify " - "constraint: {2}\");\n }\n", - constraint, name, region.constraint.getDescription()); + body << formatv(" (void)region;\n" + " if (!({0})) {\n " + "return emitOpError(\"region #\") << index << \" {1}" + "failed to " + "verify constraint: {2}\";\n }\n", + constraint, + region.name.empty() ? "" : "('" + region.name + "') ", + region.constraint.getDescription()) + << " ++index;\n" + << " }\n"; } + body << " }\n"; } void OpEmitter::genSuccessorVerifier(OpMethodBody &body) { @@ -1488,29 +1523,31 @@ &verifyCtx.withSelf("successor")) .str(); - body << formatv( - " (void)successor;\n" - " if (!({0})) {\n " - "return emitOpError(\"successor #\") << index << \"('{2}') failed to " - "verify constraint: {3}\";\n }\n", - constraint, i, successor.name, successor.constraint.getDescription()); - body << " }\n"; + body << formatv(" (void)successor;\n" + " if (!({0})) {\n " + "return emitOpError(\"successor #\") << index << \"('{1}') " + "failed to " + "verify constraint: {2}\";\n }\n", + constraint, successor.name, + successor.constraint.getDescription()) + << " ++index;\n" + << " }\n"; } body << " }\n"; } /// Add a size count trait to the given operation class. 
static void addSizeCountTrait(OpClass &opClass, StringRef traitKind, - int numNonVariadic, int numVariadic) { + int numTotal, int numVariadic) { if (numVariadic != 0) { - if (numNonVariadic == numVariadic) + if (numTotal == numVariadic) opClass.addTrait("OpTrait::Variadic" + traitKind + "s"); else opClass.addTrait("OpTrait::AtLeastN" + traitKind + "s<" + - Twine(numNonVariadic - numVariadic) + ">::Impl"); + Twine(numTotal - numVariadic) + ">::Impl"); return; } - switch (numNonVariadic) { + switch (numTotal) { case 0: opClass.addTrait("OpTrait::Zero" + traitKind); break; @@ -1518,17 +1555,21 @@ opClass.addTrait("OpTrait::One" + traitKind); break; default: - opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numNonVariadic) + + opClass.addTrait("OpTrait::N" + traitKind + "s<" + Twine(numTotal) + ">::Impl"); break; } } void OpEmitter::genTraits() { + // Add region size trait. + unsigned numRegions = op.getNumRegions(); + unsigned numVariadicRegions = op.getNumVariadicRegions(); + addSizeCountTrait(opClass, "Region", numRegions, numVariadicRegions); + + // Add result size trait. int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults(); - - // Add return size trait. addSizeCountTrait(opClass, "Result", numResults, numVariadicResults); // Add successor size trait.