diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -701,15 +701,18 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { - let arguments = (ins I32:$value, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets, - OptionalAttr:$branch_weights); + let arguments = (ins + I32:$value, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + OptionalAttr:$case_operand_segments, + OptionalAttr:$branch_weights + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ @@ -717,7 +720,7 @@ $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? `[` `\n` custom($case_values, $caseDestinations, $caseOperands, type($caseOperands), - $case_operand_offsets) `]` + $case_operand_segments) `]` attr-dict }]; @@ -734,11 +737,15 @@ let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperandsInner(index); + } /// Return a mutable range of operands for the case destination block at the /// given index. - MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsInnerMutable(index); + } }]; } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1784,14 +1784,17 @@ ``` }]; - let arguments = (ins AnyInteger:$flag, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets); + let arguments = (ins + AnyInteger:$flag, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + OptionalAttr:$case_operand_segments + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let builders = [ OpBuilder<(ins "Value":$flag, "Block *":$defaultDestination, @@ -1822,18 +1825,22 @@ $caseDestinations, $caseOperands, type($caseOperands), - $case_operand_offsets) + $case_operand_segments) `]` attr-dict }]; let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperandsInner(index); + } /// Return a mutable range of operands for the case destination block at the /// given index. - MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsInnerMutable(index); + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -324,6 +324,16 @@ Type baseType = type; } +// A nested variadic type constraint. It expands to zero or more variadic ranges +// of the base type. This class is used for supporting variadic operands and +// results. `variadicSegmentAttrName` should correspond to the name of an +// I32ElementsAttr argument that provides the sizes of the inner variadic +// operand groups. +class VariadicOfVariadic + : Variadic { + string segmentAttrName = variadicSegmentAttrName; +} + // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. class Optional : TypeConstraint { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -782,6 +782,9 @@ /// Returns the current size of the range. unsigned size() const { return length; } + /// Returns if the current range is empty. + bool empty() const { return size() == 0; } + /// Allow implicit conversion to an OperandRange. operator OperandRange() const; @@ -801,7 +804,7 @@ /// Optional set of operand segments that should be updated when mutating the /// length of this range. - SmallVector, 1> operandSegments; + SmallVector operandSegments; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -48,6 +48,8 @@ bool isOptional() const; // Returns true if this operand/result is variadic. bool isVariadic() const; + // Returns true if this operand/result is a variadic of a variadic constraint. + bool isVariadicOfVariadic() const; // Returns true if this is a variable length type constraint. This is either // variadic or optional. bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -40,6 +40,13 @@ // Returns true if this is a variadic type constraint. bool isVariadic() const; + // Returns true if this is a nested variadic type constraint. + bool isVariadicOfVariadic() const; + + // Return the segment size attribute used if this is a variadic of variadic + // constraint. Asserts isVariadicOfVariadic() is true. + StringRef getVariadicOfVariadicSegmentSizeAttr() const; + // Returns true if this is a variable length type constraint. This is either // variadic or optional. bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -520,7 +520,7 @@ /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, /*caseDestinations=*/caseDest, - /*caseOperands=*/ArrayRef(), + /*caseOperands=*/ArrayRef({ValueRange(), ValueRange()}), /*branchWeights=*/ArrayRef()); return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/SourceMgr.h" #include +#include using namespace mlir; using namespace mlir::LLVM; @@ -235,28 +236,16 @@ ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } ElementsAttr caseValuesAttr; if (!caseValues.empty()) caseValuesAttr = builder.getI32VectorAttr(caseValues); - ElementsAttr caseOperandOffsetsAttr; - if (!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); ElementsAttr weightsAttr; if (!branchWeights.empty()) weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, + weightsAttr, defaultDestination, caseDestinations); } /// ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? @@ -266,10 +255,10 @@ SmallVectorImpl &caseDestinations, SmallVectorImpl &caseOperands, SmallVectorImpl &caseOperandTypes, - ElementsAttr &caseOperandOffsets) { + ElementsAttr &caseOperandSegments) { SmallVector values; - SmallVector offsets; - int32_t value, offset = 0; + SmallVector caseSizes; + int32_t value = 0; do { OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); if (values.empty() && !integerParseResult.hasValue()) @@ -290,13 +279,14 @@ } caseDestinations.push_back(destination); caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseSizes.push_back(operands.size()); } while (!parser.parseOptionalComma()); Builder &builder = parser.getBuilder(); caseValues = builder.getI32VectorAttr(values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); + + if (!caseSizes.empty()) + caseOperandSegments = builder.getI32VectorAttr(caseSizes); return success(); } @@ -306,7 +296,7 @@ SuccessorRange caseDestinations, OperandRange caseOperands, TypeRange caseOperandTypes, - ElementsAttr caseOperandOffsets) { + ElementsAttr caseOperandSegments) { if (!caseValues) return; @@ -341,28 +331,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -27,6 +27,7 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc" @@ -2056,21 +2057,8 @@ DenseIntElementsAttr caseValues, BlockRange caseDestinations, ArrayRef caseOperands) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } - DenseIntElementsAttr caseOperandOffsetsAttr; - if (!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); - - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValues, caseOperandOffsetsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValues, + defaultDestination, caseDestinations); } void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, @@ -2098,7 +2086,7 @@ SmallVectorImpl &caseDestinations, SmallVectorImpl &caseOperands, SmallVectorImpl &caseOperandTypes, - DenseIntElementsAttr &caseOperandOffsets) { + DenseIntElementsAttr &caseOperandSegments) { if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) || failed(parser.parseSuccessor(defaultDestination))) return failure(); @@ -2110,9 +2098,8 @@ } SmallVector values; - SmallVector offsets; + SmallVector caseSizes; unsigned bitWidth = flagType.getIntOrFloatBitWidth(); - int64_t offset = 0; while (succeeded(parser.parseOptionalComma())) { int64_t value = 0; if (failed(parser.parseInteger(value))) @@ -2132,8 +2119,7 @@ } caseDestinations.push_back(destination); caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseSizes.push_back(operands.size()); } if (values.empty()) @@ -2143,7 +2129,7 @@ ShapedType caseValueType = VectorType::get(static_cast(values.size()), flagType); caseValues = DenseIntElementsAttr::get(caseValueType, values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); + caseOperandSegments = builder.getI32VectorAttr(caseSizes); return success(); } @@ -2194,28 +2180,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -12,6 +12,10 @@ using namespace mlir; using namespace mlir::tblgen; +//===----------------------------------------------------------------------===// +// NamedTypeConstraint +//===----------------------------------------------------------------------===// + bool NamedTypeConstraint::hasPredicate() const { return !constraint.getPredicate().isNull(); } @@ -19,3 +23,7 @@ bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); } bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); } + +bool NamedTypeConstraint::isVariadicOfVariadic() const { + return constraint.isVariadicOfVariadic(); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -458,6 +458,13 @@ results.push_back({name, TypeConstraint(resultDef)}); if (!name.empty()) argumentsAndResultsIndex[name] = resultIndex(i); + + // We currently only support VariadicOfVariadic operands. + if (results.back().constraint.isVariadicOfVariadic()) { + PrintFatalError( + def.getLoc(), + "'VariadicOfVariadic' results are currently not supported"); + } } // Handle successors diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -36,6 +36,15 @@ return def->isSubClassOf("Variadic"); } +bool TypeConstraint::isVariadicOfVariadic() const { + return def->isSubClassOf("VariadicOfVariadic"); +} + +StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const { + assert(isVariadicOfVariadic()); + return def->getValueAsString("segmentAttrName"); +} + // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise. Optional TypeConstraint::getBuilderCall() const { diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/Path.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -90,6 +91,17 @@ unsigned size = *(sizeAttrValues.begin() + index); return {start, size}; )"; +// The logic to calculate the actual value range for a declared operand/result +// of an op with variadic of variadic operands/results. +// +// {0}: The code within attrSizedSegmentValueRangeCalcCode. +// {1}: Additional slice parameters. +const char *attrSizedSegmentVariadicOfVariadicValueRangeCalcCode = R"( + auto operandSegment = [&]() -> std::pair {{ + {0} + }(); + return operands.slice(operandSegment.first, operandSegment.second{1}); +)"; // The logic to build a range of either operand or result values. // @@ -422,16 +434,20 @@ // Builds the parameter list for build() method of this op. This method writes // to `paramList` the comma-separated parameter list and updates // `resultTypeNames` with the names for parameters for specifying result - // types. The given `typeParamKind` and `attrParamKind` controls how result - // types and attributes are placed in the parameter list. + // types. `inferredAttributes` is populated with any attributes that are + // elided from the build list. The given `typeParamKind` and `attrParamKind` + // controls how result types and attributes are placed in the parameter list. void buildParamList(llvm::SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. - void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr = false); + void + genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); @@ -956,7 +972,7 @@ // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. static void generateNamedOperandGetters(const Operator &op, Class &opClass, - StringRef sizeAttrInit, + bool isAdaptor, StringRef sizeAttrInit, StringRef rangeType, StringRef rangeBeginCall, StringRef rangeSizeCall, @@ -1014,6 +1030,20 @@ } else if (operand.isVariadic()) { m = opClass.addMethodAndPrune(rangeType, operand.name); m->body() << " return getODSOperands(" << i << ");"; + + // If this operand is a nested variadic, we also generate accessors for + // the inner operands. + if (operand.isVariadicOfVariadic()) { + m = opClass.addMethodAndPrune(rangeType, (operand.name + "Inner").str(), + "unsigned", "index"); + m->body() << " auto operands = " << operand.name << "();\n" + << " auto sizeAttr = " + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << (isAdaptor ? "" : "Attr") << "();\n" + << llvm::formatv( + attrSizedSegmentVariadicOfVariadicValueRangeCalcCode, + attrSizedSegmentValueRangeCalcCode, ""); + } } else { m = opClass.addMethodAndPrune("::mlir::Value", operand.name); m->body() << " return *getODSOperands(" << i << ").begin();"; @@ -1033,6 +1063,7 @@ generateNamedOperandGetters( op, opClass, + /*isAdaptor=*/false, /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", @@ -1041,6 +1072,20 @@ } void OpEmitter::genNamedOperandSetters() { + // The logic to calculate the size attribute of a VariadicOfVariadic operand. + // + // {0}: The name of the segment size attribute. + // {1}: The main operand name. + const char *variadicOfVariadicSizeAttrCalcCode = R"( + auto operands = {1}Mutable(); + auto sizeNamedAttr = (*this)->getAttrDictionary().getNamed({0}AttrName()); + if (!sizeNamedAttr) {{ + assert(operands.empty() && "invalid inner operand index"); + return operands; + } + auto sizeAttr = sizeNamedAttr->second.cast<::mlir::DenseElementsAttr>(); +)"; + auto *attrSizedOperands = op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments"); for (int i = 0, e = op.getNumOperands(); i != e; ++i) { @@ -1058,6 +1103,23 @@ << "u, *getOperation()->getAttrDictionary().getNamed(" "operand_segment_sizesAttrName()))"; body << ");\n"; + + // If this operand is a nested variadic, we also generate accessors for + // the inner operands. + if (operand.isVariadicOfVariadic()) { + m = opClass.addMethodAndPrune("::mlir::MutableOperandRange", + (operand.name + "InnerMutable").str(), + "unsigned", "index"); + m->body() << llvm::formatv(variadicOfVariadicSizeAttrCalcCode, + operand.constraint + .getVariadicOfVariadicSegmentSizeAttr(), + operand.name) + << llvm::formatv( + attrSizedSegmentVariadicOfVariadicValueRangeCalcCode, + attrSizedSegmentValueRangeCalcCode, + ", ::mlir::MutableOperandRange::OperandSegment(index, " + "*sizeNamedAttr)"); + } } } @@ -1209,9 +1271,11 @@ // inferring result type. auto emit = [&](AttrParamKind attrType, TypeParamKind paramKind, bool inferType) { - llvm::SmallVector paramList; - llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, paramKind, attrType); + SmallVector paramList; + SmallVector resultNames; + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, paramKind, + attrType); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1219,8 +1283,9 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder( - body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, + /*isRawValueAttr=*/attrType == + AttrParamKind::UnwrappedValue); // Push all result types to the operation state @@ -1388,7 +1453,9 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { llvm::SmallVector paramList; llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::None); + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, + TypeParamKind::None); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1396,7 +1463,7 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder(body); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes); auto numResults = op.getNumResults(); if (numResults == 0) @@ -1588,6 +1655,7 @@ } void OpEmitter::buildParamList(SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { @@ -1626,10 +1694,6 @@ } // Add parameters for all arguments (operands and attributes). - - int numOperands = 0; - int numAttrs = 0; - int defaultValuedAttrStartIndex = op.getNumArgs(); if (attrParamKind == AttrParamKind::UnwrappedValue) { // Calculate the start index from which we can attach default values in the @@ -1655,54 +1719,68 @@ } } - for (int i = 0, e = op.getNumArgs(); i < e; ++i) { - auto argument = op.getArg(i); - if (argument.is()) { - const auto &operand = op.getOperand(numOperands); - StringRef type = - operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (operand.isOptional()) - properties = OpMethodParameter::PP_Optional; + /// Collect any inferred attributes. + for (const NamedTypeConstraint &operand : op.getOperands()) { + if (operand.isVariadicOfVariadic()) { + inferredAttributes.insert( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); + } + } - paramList.emplace_back(type, getArgumentName(op, numOperands), - properties); - ++numOperands; - } else { - const auto &namedAttr = op.getAttribute(numAttrs); - const auto &attr = namedAttr.attr; + for (unsigned i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { + Argument arg = op.getArg(i); + if (const auto *operand = arg.dyn_cast()) { + StringRef type; + if (operand->isVariadicOfVariadic()) + type = "::llvm::ArrayRef<::mlir::ValueRange>"; + else if (operand->isVariadic()) + type = "::mlir::ValueRange"; + else + type = "::mlir::Value"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (attr.isOptional()) + if (operand->isOptional()) properties = OpMethodParameter::PP_Optional; + paramList.emplace_back(type, getArgumentName(op, numOperands++), + properties); + continue; + } + const NamedAttribute &namedAttr = *arg.get(); + const Attribute &attr = namedAttr.attr; - StringRef type; - switch (attrParamKind) { - case AttrParamKind::WrappedAttr: + // inferred attributes don't need to be added to the param list. + if (inferredAttributes.contains(namedAttr.name)) + continue; + + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (attr.isOptional()) + properties = OpMethodParameter::PP_Optional; + + StringRef type; + switch (attrParamKind) { + case AttrParamKind::WrappedAttr: + type = attr.getStorageType(); + break; + case AttrParamKind::UnwrappedValue: + if (canUseUnwrappedRawValue(attr)) + type = attr.getReturnType(); + else type = attr.getStorageType(); - break; - case AttrParamKind::UnwrappedValue: - if (canUseUnwrappedRawValue(attr)) - type = attr.getReturnType(); - else - type = attr.getStorageType(); - break; - } + break; + } - std::string defaultValue; - // Attach default value if requested and possible. - if (attrParamKind == AttrParamKind::UnwrappedValue && - i >= defaultValuedAttrStartIndex) { - bool isString = attr.getReturnType() == "::llvm::StringRef"; - if (isString) - defaultValue.append("\""); - defaultValue += attr.getDefaultValue(); - if (isString) - defaultValue.append("\""); - } - paramList.emplace_back(type, namedAttr.name, defaultValue, properties); - ++numAttrs; + // Attach default value if requested and possible. + std::string defaultValue; + if (attrParamKind == AttrParamKind::UnwrappedValue && + i >= defaultValuedAttrStartIndex) { + bool isString = attr.getReturnType() == "::llvm::StringRef"; + if (isString) + defaultValue.append("\""); + defaultValue += attr.getDefaultValue(); + if (isString) + defaultValue.append("\""); } + paramList.emplace_back(type, namedAttr.name, defaultValue, properties); } /// Insert parameters for each successor. @@ -1719,12 +1797,31 @@ llvm::formatv("{0}Count", region.name).str()); } -void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr) { +void OpEmitter::genCodeForAddingArgAndRegionForBuilder( + OpMethodBody &body, llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr) { // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { std::string argName = getArgumentName(op, i); - if (op.getOperand(i).isOptional()) + NamedTypeConstraint &operand = op.getOperand(i); + if (operand.constraint.isVariadicOfVariadic()) { + body << " for (::mlir::ValueRange range : " << argName << ")\n " + << builderOpState << ".addOperands(range);\n"; + + // Add the segment attribute if operands were provided. + body << " if (!" << argName << ".empty()) {\n" + << " SmallVector rangeSegments;\n" + << " for (::mlir::ValueRange range : " << argName << ")\n" + << " rangeSegments.push_back(range.size());\n" + << " " << builderOpState << ".addAttribute(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName(" << builderOpState << ".name), " << odsBuilder + << ".getI32VectorAttr(rangeSegments));" + << " }\n"; + continue; + } + + if (operand.isOptional()) body << " if (" << argName << ")\n "; body << " " << builderOpState << ".addOperands(" << argName << ");\n"; } @@ -1736,12 +1833,24 @@ << ".name), " << "odsBuilder.getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isOptional()) - body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; - else if (op.getOperand(i).isVariadic()) - body << "static_cast(" << getArgumentName(op, i) << ".size())"; - else + const NamedTypeConstraint &operand = op.getOperand(i); + if (!operand.isVariableLength()) { body << "1"; + return; + } + + std::string operandName = getArgumentName(op, i); + if (operand.isOptional()) { + body << "(" << operandName << " ? 1 : 0)"; + } else if (operand.isVariadicOfVariadic()) { + body << llvm::formatv( + "static_cast(std::accumulate({0}.begin(), {0}.end(), 0, " + "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " + "range.size(); }))", + operandName); + } else { + body << "static_cast(" << getArgumentName(op, i) << ".size())"; + } }); body << "}));\n"; } @@ -1749,38 +1858,38 @@ // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; - if (!attr.isDerivedAttr()) { - bool emitNotNullCheck = attr.isOptional(); - if (emitNotNullCheck) - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; - - if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { - // If this is a raw value, then we need to wrap it in an Attribute - // instance. - FmtContext fctx; - fctx.withBuilder("odsBuilder"); - - std::string builderTemplate = - std::string(attr.getConstBuilderTemplate()); - - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - if (StringRef(builderTemplate).contains("\"$0\"")) - builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); - - std::string value = - std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, namedAttr.name, value); - } else { - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", - builderOpState, namedAttr.name); - } - if (emitNotNullCheck) - body << " }\n"; + if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) + continue; + + bool emitNotNullCheck = attr.isOptional(); + if (emitNotNullCheck) + body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // If this is a raw value, then we need to wrap it in an Attribute + // instance. + FmtContext fctx; + fctx.withBuilder("odsBuilder"); + + std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); + + // For StringAttr, its constant builder call will wrap the input in + // quotes, which is correct for normal string literals, but incorrect + // here given we use function arguments. So we need to strip the + // wrapping quotes. + if (StringRef(builderTemplate).contains("\"$0\"")) + builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + + std::string value = + std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", + builderOpState, namedAttr.name, value); + } else { + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", + builderOpState, namedAttr.name); } + if (emitNotNullCheck) + body << " }\n"; } // Create the correct number of regions. @@ -2430,7 +2539,8 @@ } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); - generateNamedOperandGetters(op, adaptor, sizeAttrInit, + generateNamedOperandGetters(op, adaptor, + /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()",