diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -229,6 +229,17 @@ indicate that all variable length operands have the same number of dynamic values. +#### VariadicOfVariadic operands + +To declare a variadic operand that has a variadic number of sub-ranges, wrap the +`TypeConstraint` for the operand with `VariadicOfVariadic<..., +"">`. + +The second field of the `VariadicOfVariadic` is the name of an `I32ElementsAttr` +argument that contains the sizes of the variadic sub-ranges. This attribute will +be used when determining the size of sub-ranges, or when updating the size of +sub-ranges. + #### Optional operands To declare an optional operand, wrap the `TypeConstraint` for the operand with @@ -717,6 +728,8 @@ - Single: `OpAsmParser::OperandType &` - Optional: `Optional &` - Variadic: `SmallVectorImpl &` + - VariadicOfVariadic: + `SmallVectorImpl> &` * Ref Directives - A reference directive is passed to the parser using the same mapping as the input operand. For example, a single region would be passed as a @@ -731,6 +744,7 @@ - Single: `Type &` - Optional: `Type &` - Variadic: `SmallVectorImpl &` + - VariadicOfVariadic: `SmallVectorImpl> &` * `attr-dict` Directive: `NamedAttrList &` When a variable is optional, the value should only be specified if the variable @@ -749,6 +763,7 @@ - Single: `Value` - Optional: `Value` - Variadic: `OperandRange` + - VariadicOfVariadic: `OperandRangeRange` * Ref Directives - A reference directive is passed to the printer using the same mapping as the input operand. For example, a single region would be passed as a @@ -763,6 +778,7 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` + - VariadicOfVariadic: `TypeRangeRange` * `attr-dict` Directive: `DictionaryAttr` When a variable is optional, the provided value may be null. @@ -923,7 +939,7 @@ When this boolean field is set to `true`, it indicates that the op implements a `canonicalize` method for simple "matchAndRewrite" style canonicalization -patterns. If `hasCanonicalizer` is 0, then an implementation of +patterns. If `hasCanonicalizer` is 0, then an implementation of `::getCanonicalizationPatterns()` is implemented to call this function. ### `hasFolder` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -701,23 +701,25 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { - let arguments = (ins I32:$value, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets, - OptionalAttr:$branch_weights); + let arguments = (ins + I32:$value, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + ElementsAttr:$case_operand_segments, + OptionalAttr:$branch_weights + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ $value `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? `[` `\n` custom($case_values, $caseDestinations, - $caseOperands, type($caseOperands), - $case_operand_offsets) `]` + $caseOperands, type($caseOperands)) `]` attr-dict }]; @@ -734,11 +736,15 @@ let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperands()[index]; + } /// Return a mutable range of operands for the case destination block at the /// given index. - MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsMutable()[index]; + } }]; } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1784,14 +1784,17 @@ ``` }]; - let arguments = (ins AnyInteger:$flag, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets); + let arguments = (ins + AnyInteger:$flag, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + I32ElementsAttr:$case_operand_segments + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let builders = [ OpBuilder<(ins "Value":$flag, "Block *":$defaultDestination, @@ -1821,19 +1824,22 @@ $case_values, $caseDestinations, $caseOperands, - type($caseOperands), - $case_operand_offsets) + type($caseOperands)) `]` attr-dict }]; let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperands()[index]; + } /// Return a mutable range of operands for the case destination block at the /// given index. - MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsMutable()[index]; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -324,6 +324,16 @@ Type baseType = type; } +// A nested variadic type constraint. It expands to zero or more variadic ranges +// of the base type. This class is used for supporting variadic operands and +// results. `variadicSegmentAttrName` should correspond to the name of an +// I32ElementsAttr argument that provides the sizes of the inner variadic +// operand groups. +class VariadicOfVariadic + : Variadic { + string segmentAttrName = variadicSegmentAttrName; +} + // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. class Optional : TypeConstraint { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -36,12 +36,14 @@ class Dialect; class DictionaryAttr; class ElementsAttr; +class MutableOperandRangeRange; class Operation; struct OperationState; class OpAsmParser; class OpAsmParserResult; class OpAsmPrinter; class OperandRange; +class OperandRangeRange; class OpFoldResult; class ParseResult; class Pattern; @@ -727,6 +729,10 @@ /// must not be empty. unsigned getBeginOperandIndex() const; + /// Split this range into a set of contiguous subranges using the given + /// elements attribute, which contains the sizes of the sub ranges. + OperandRangeRange split(ElementsAttr segmentSizes) const; + private: /// See `llvm::detail::indexed_accessor_range_base` for details. static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { @@ -741,6 +747,42 @@ friend RangeBaseT; }; +//===----------------------------------------------------------------------===// +// OperandRangeRange + +/// This class represents a contiguous range of operand ranges, e.g. from a +/// VariadicOfVariadic operand group. +class OperandRangeRange final + : public llvm::indexed_accessor_range< + OperandRangeRange, std::pair, OperandRange, + OperandRange, OperandRange> { + using OwnerT = std::pair; + using RangeBaseT = + llvm::indexed_accessor_range; + +public: + using RangeBaseT::RangeBaseT; + + /// Returns the range of types of the values within this range. + TypeRangeRange getTypes() const { return TypeRangeRange(*this); } + auto getType() const { return getTypes(); } + + /// Construct a range given a parent set of operands, and an I32 elements + /// attribute containing the sizes of the sub ranges. + OperandRangeRange(OperandRange operands, Attribute operandSegments); + + /// Flatten all of the sub ranges into a single contiguous operand range. + OperandRange join() const; + +private: + /// See `llvm::indexed_accessor_range` for details. + static OperandRange dereference(const OwnerT &object, ptrdiff_t index); + + /// Allow access to `dereference_iterator`. + friend RangeBaseT; +}; + //===----------------------------------------------------------------------===// // MutableOperandRange @@ -761,8 +803,9 @@ MutableOperandRange(Operation *owner); /// Slice this range into a sub range, with the additional operand segment. - MutableOperandRange slice(unsigned subStart, unsigned subLen, - Optional segment = llvm::None); + MutableOperandRange + slice(unsigned subStart, unsigned subLen, + Optional segment = llvm::None) const; /// Append the given values to the range. void append(ValueRange values); @@ -782,12 +825,19 @@ /// Returns the current size of the range. unsigned size() const { return length; } + /// Returns if the current range is empty. + bool empty() const { return size() == 0; } + /// Allow implicit conversion to an OperandRange. operator OperandRange() const; /// Returns the owning operation. Operation *getOwner() const { return owner; } + /// Split this range into a set of contiguous subranges using the given + /// elements attribute, which contains the sizes of the sub ranges. + MutableOperandRangeRange split(NamedAttribute segmentSizes) const; + private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); @@ -801,7 +851,46 @@ /// Optional set of operand segments that should be updated when mutating the /// length of this range. - SmallVector, 1> operandSegments; + SmallVector operandSegments; +}; + +//===----------------------------------------------------------------------===// +// MutableOperandRangeRange + +/// This class represents a contiguous range of mutable operand ranges, e.g. +/// from a VariadicOfVariadic operand group. +class MutableOperandRangeRange final + : public llvm::indexed_accessor_range< + MutableOperandRangeRange, + std::pair, MutableOperandRange, + MutableOperandRange, MutableOperandRange> { + using OwnerT = std::pair; + using RangeBaseT = + llvm::indexed_accessor_range; + +public: + using RangeBaseT::RangeBaseT; + + /// Construct a range given a parent set of operands, and an I32 tensor + /// elements attribute containing the sizes of the sub ranges. + MutableOperandRangeRange(const MutableOperandRange &operands, + NamedAttribute operandSegmentAttr); + + /// Flatten all of the sub ranges into a single contiguous mutable operand + /// range. + MutableOperandRange join() const; + + /// Allow implicit conversion to an OperandRangeRange. + operator OperandRangeRange() const; + +private: + /// See `llvm::indexed_accessor_range` for details. + static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index); + + /// Allow access to `dereference_iterator`. + friend RangeBaseT; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/Sequence.h" namespace mlir { class OperandRange; @@ -88,6 +89,35 @@ return os; } +//===----------------------------------------------------------------------===// +// TypeRangeRange + +using TypeRangeRangeIterator = + llvm::mapped_iterator::iterator, + std::function>; + +/// This class provides an abstraction for a range of TypeRange. This is useful +/// when accessing the types of a range of ranges, such as when using +/// OperandRangeRange. +class TypeRangeRange : public llvm::iterator_range { +public: + template + TypeRangeRange(const RangeT &range) + : TypeRangeRange(llvm::seq(0, range.size()), range) {} + +private: + template + TypeRangeRange(llvm::iota_range sizeRange, const RangeT &range) + : llvm::iterator_range( + {sizeRange.begin(), getRangeFn(range)}, + {sizeRange.end(), nullptr}) {} + + template + static std::function getRangeFn(const RangeT &range) { + return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); }; + } +}; + //===----------------------------------------------------------------------===// // ValueTypeRange diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -48,6 +48,8 @@ bool isOptional() const; // Returns true if this operand/result is variadic. bool isVariadic() const; + // Returns true if this operand/result is a variadic of a variadic constraint. + bool isVariadicOfVariadic() const; // Returns true if this is a variable length type constraint. This is either // variadic or optional. bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -40,6 +40,13 @@ // Returns true if this is a variadic type constraint. bool isVariadic() const; + // Returns true if this is a nested variadic type constraint. + bool isVariadicOfVariadic() const; + + // Return the segment size attribute used if this is a variadic of variadic + // constraint. Asserts isVariadicOfVariadic() is true. + StringRef getVariadicOfVariadicSegmentSizeAttr() const; + // Returns true if this is a variable length type constraint. This is either // variadic or optional. bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -520,7 +520,7 @@ /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, /*caseDestinations=*/caseDest, - /*caseOperands=*/ArrayRef(), + /*caseOperands=*/ArrayRef({ValueRange(), ValueRange()}), /*branchWeights=*/ArrayRef()); return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/SourceMgr.h" #include +#include using namespace mlir; using namespace mlir::LLVM; @@ -235,41 +236,27 @@ ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } ElementsAttr caseValuesAttr; if (!caseValues.empty()) caseValuesAttr = builder.getI32VectorAttr(caseValues); - ElementsAttr caseOperandOffsetsAttr; - if (!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); ElementsAttr weightsAttr; if (!branchWeights.empty()) weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, + weightsAttr, defaultDestination, caseDestinations); } /// ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? -static ParseResult -parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues, - SmallVectorImpl &caseDestinations, - SmallVectorImpl &caseOperands, - SmallVectorImpl &caseOperandTypes, - ElementsAttr &caseOperandOffsets) { +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, ElementsAttr &caseValues, + SmallVectorImpl &caseDestinations, + SmallVectorImpl> &caseOperands, + SmallVectorImpl> &caseOperandTypes) { SmallVector values; - SmallVector offsets; - int32_t value, offset = 0; + int32_t value = 0; do { OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); if (values.empty() && !integerParseResult.hasValue()) @@ -281,32 +268,28 @@ Block *destination; SmallVector operands; + SmallVector operandTypes; if (parser.parseColon() || parser.parseSuccessor(destination)) return failure(); if (!parser.parseOptionalLParen()) { if (parser.parseRegionArgumentList(operands) || - parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen()) + parser.parseColonTypeList(operandTypes) || parser.parseRParen()) return failure(); } caseDestinations.push_back(destination); - caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); } while (!parser.parseOptionalComma()); - Builder &builder = parser.getBuilder(); - caseValues = builder.getI32VectorAttr(values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); - + caseValues = parser.getBuilder().getI32VectorAttr(values); return success(); } static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, ElementsAttr caseValues, SuccessorRange caseDestinations, - OperandRange caseOperands, - TypeRange caseOperandTypes, - ElementsAttr caseOperandOffsets) { + OperandRangeRange caseOperands, + TypeRangeRange caseOperandTypes) { if (!caseValues) return; @@ -317,7 +300,7 @@ p << " "; p << std::get<0>(i).getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++)); + p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); }, [&] { p << ','; @@ -341,28 +324,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -27,6 +27,7 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc" @@ -2056,21 +2057,8 @@ DenseIntElementsAttr caseValues, BlockRange caseDestinations, ArrayRef caseOperands) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } - DenseIntElementsAttr caseOperandOffsetsAttr; - if (!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); - - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValues, caseOperandOffsetsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValues, + defaultDestination, caseDestinations); } void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, @@ -2089,16 +2077,14 @@ /// ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* -static ParseResult -parseSwitchOpCases(OpAsmParser &parser, Type &flagType, - Block *&defaultDestination, - SmallVectorImpl &defaultOperands, - SmallVectorImpl &defaultOperandTypes, - DenseIntElementsAttr &caseValues, - SmallVectorImpl &caseDestinations, - SmallVectorImpl &caseOperands, - SmallVectorImpl &caseOperandTypes, - DenseIntElementsAttr &caseOperandOffsets) { +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &flagType, Block *&defaultDestination, + SmallVectorImpl &defaultOperands, + SmallVectorImpl &defaultOperandTypes, + DenseIntElementsAttr &caseValues, + SmallVectorImpl &caseDestinations, + SmallVectorImpl> &caseOperands, + SmallVectorImpl> &caseOperandTypes) { if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) || failed(parser.parseSuccessor(defaultDestination))) return failure(); @@ -2110,9 +2096,7 @@ } SmallVector values; - SmallVector offsets; unsigned bitWidth = flagType.getIntOrFloatBitWidth(); - int64_t offset = 0; while (succeeded(parser.parseOptionalComma())) { int64_t value = 0; if (failed(parser.parseInteger(value))) @@ -2121,30 +2105,26 @@ Block *destination; SmallVector operands; + SmallVector operandTypes; if (failed(parser.parseColon()) || failed(parser.parseSuccessor(destination))) return failure(); if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseRegionArgumentList(operands)) || - failed(parser.parseColonTypeList(caseOperandTypes)) || + failed(parser.parseColonTypeList(operandTypes)) || failed(parser.parseRParen())) return failure(); } caseDestinations.push_back(destination); - caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); } - if (values.empty()) - return success(); - - Builder &builder = parser.getBuilder(); - ShapedType caseValueType = - VectorType::get(static_cast(values.size()), flagType); - caseValues = DenseIntElementsAttr::get(caseValueType, values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); - + if (!values.empty()) { + ShapedType caseValueType = + VectorType::get(static_cast(values.size()), flagType); + caseValues = DenseIntElementsAttr::get(caseValueType, values); + } return success(); } @@ -2152,8 +2132,7 @@ OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, - OperandRange caseOperands, TypeRange caseOperandTypes, - ElementsAttr caseOperandOffsets) { + OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) { p << " default: "; p.printSuccessorAndUseList(defaultDestination, defaultOperands); @@ -2166,7 +2145,7 @@ p << " "; p << caseValues.getValue(i).getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i)); + p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]); } p.printNewline(); } @@ -2194,28 +2173,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -12,9 +12,11 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "llvm/ADT/BitVector.h" +#include using namespace mlir; @@ -394,13 +396,38 @@ OperandRange::OperandRange(Operation *op) : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} -/// Return the operand index of the first element of this range. The range -/// must not be empty. unsigned OperandRange::getBeginOperandIndex() const { assert(!empty() && "range must not be empty"); return base->getOperandNumber(); } +OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const { + return OperandRangeRange(*this, segmentSizes); +} + +//===----------------------------------------------------------------------===// +// OperandRangeRange + +OperandRangeRange::OperandRangeRange(OperandRange operands, + Attribute operandSegments) + : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, + operandSegments.cast().size()) {} + +OperandRange OperandRangeRange::join() const { + const OwnerT &owner = getBase(); + auto sizeData = owner.second.cast().getValues(); + return OperandRange(owner.first, + std::accumulate(sizeData.begin(), sizeData.end(), 0)); +} + +OperandRange OperandRangeRange::dereference(const OwnerT &object, + ptrdiff_t index) { + auto sizeData = object.second.cast().getValues(); + uint32_t startIndex = + std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); +} + //===----------------------------------------------------------------------===// // MutableOperandRange @@ -419,7 +446,7 @@ /// Slice this range into a sub range, with the additional operand segment. MutableOperandRange MutableOperandRange::slice(unsigned subStart, unsigned subLen, - Optional segment) { + Optional segment) const { assert((subStart + subLen) <= length && "invalid sub-range"); MutableOperandRange subSlice(owner, start + subStart, subLen, operandSegments); @@ -475,6 +502,11 @@ return owner->getOperands().slice(start, length); } +MutableOperandRangeRange +MutableOperandRange::split(NamedAttribute segmentSizes) const { + return MutableOperandRangeRange(*this, segmentSizes); +} + /// Update the length of this range to the one provided. void MutableOperandRange::updateLength(unsigned newLength) { int32_t diff = int32_t(newLength) - int32_t(length); @@ -490,6 +522,35 @@ } } +//===----------------------------------------------------------------------===// +// MutableOperandRangeRange + +MutableOperandRangeRange::MutableOperandRangeRange( + const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) + : MutableOperandRangeRange( + OwnerT(operands, operandSegmentAttr), 0, + operandSegmentAttr.second.cast().size()) {} + +MutableOperandRange MutableOperandRangeRange::join() const { + return getBase().first; +} + +MutableOperandRangeRange::operator OperandRangeRange() const { + return OperandRangeRange(getBase().first, + getBase().second.second.cast()); +} + +MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, + ptrdiff_t index) { + auto sizeData = + object.second.second.cast().getValues(); + uint32_t startIndex = + std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + return object.first.slice( + startIndex, *(sizeData.begin() + index), + MutableOperandRange::OperandSegment(index, object.second)); +} + //===----------------------------------------------------------------------===// // ValueRange diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -12,6 +12,10 @@ using namespace mlir; using namespace mlir::tblgen; +//===----------------------------------------------------------------------===// +// NamedTypeConstraint +//===----------------------------------------------------------------------===// + bool NamedTypeConstraint::hasPredicate() const { return !constraint.getPredicate().isNull(); } @@ -19,3 +23,7 @@ bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); } bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); } + +bool NamedTypeConstraint::isVariadicOfVariadic() const { + return constraint.isVariadicOfVariadic(); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -458,6 +458,13 @@ results.push_back({name, TypeConstraint(resultDef)}); if (!name.empty()) argumentsAndResultsIndex[name] = resultIndex(i); + + // We currently only support VariadicOfVariadic operands. + if (results.back().constraint.isVariadicOfVariadic()) { + PrintFatalError( + def.getLoc(), + "'VariadicOfVariadic' results are currently not supported"); + } } // Handle successors @@ -577,8 +584,7 @@ StringRef Operator::getAssemblyFormat() const { return TypeSwitch(def.getValueInit("assemblyFormat")) - .Case( - [&](auto *init) { return init->getValue(); }); + .Case([&](auto *init) { return init->getValue(); }); } void Operator::print(llvm::raw_ostream &os) const { diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -36,6 +36,15 @@ return def->isSubClassOf("Variadic"); } +bool TypeConstraint::isVariadicOfVariadic() const { + return def->isSubClassOf("VariadicOfVariadic"); +} + +StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const { + assert(isVariadicOfVariadic()); + return def->getValueAsString("segmentAttrName"); +} + // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise. Optional TypeConstraint::getBuilderCall() const { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1661,6 +1661,14 @@ let arguments = (ins Variadic:$operand); let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; } +def FormatVariadicOfVariadicOperand + : TEST_Op<"format_variadic_of_variadic_operand"> { + let arguments = (ins + VariadicOfVariadic:$operand, + I32ElementsAttr:$operand_segments + ); + let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; +} def FormatMultipleVariadicOperands : TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -151,6 +151,9 @@ // CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64 test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64 +// CHECK: test.format_variadic_of_variadic_operand (%[[I64]], %[[I64]]), (), (%[[I64]]) : (i64, i64), (), (i64) +test.format_variadic_of_variadic_operand (%i64, %i64), (), (%i64) : (i64, i64), (), (i64) + // CHECK: test.format_multiple_variadic_operands (%[[I64]], %[[I64]], %[[I64]]), (%[[I64]], %[[I32]] : i64, i32) test.format_multiple_variadic_operands (%i64, %i64, %i64), (%i64, %i32 : i64, i32) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/Path.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" @@ -90,6 +91,23 @@ unsigned size = *(sizeAttrValues.begin() + index); return {start, size}; )"; +// The logic to calculate the actual value range for a declared operand +// of an op with variadic of variadic operands within the OpAdaptor. +// +// {0}: The name of the segment attribute. +// {1}: The index of the main operand. +const char *variadicOfVariadicAdaptorCalcCode = R"( + auto tblgenTmpOperands = getODSOperands({1}); + auto sizeAttrValues = {0}().getValues(); + auto sizeAttrIt = sizeAttrValues.begin(); + + ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups; + for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{ + tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt)); + tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt); + } + return tblgenTmpOperandGroups; +)"; // The logic to build a range of either operand or result values. // @@ -422,16 +440,20 @@ // Builds the parameter list for build() method of this op. This method writes // to `paramList` the comma-separated parameter list and updates // `resultTypeNames` with the names for parameters for specifying result - // types. The given `typeParamKind` and `attrParamKind` controls how result - // types and attributes are placed in the parameter list. + // types. `inferredAttributes` is populated with any attributes that are + // elided from the build list. The given `typeParamKind` and `attrParamKind` + // controls how result types and attributes are placed in the parameter list. void buildParamList(llvm::SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. - void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr = false); + void + genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); @@ -956,7 +978,7 @@ // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method. static void generateNamedOperandGetters(const Operator &op, Class &opClass, - StringRef sizeAttrInit, + bool isAdaptor, StringRef sizeAttrInit, StringRef rangeType, StringRef rangeBeginCall, StringRef rangeSizeCall, @@ -1011,6 +1033,20 @@ m->body() << " auto operands = getODSOperands(" << i << ");\n" << " return operands.empty() ? ::mlir::Value() : *operands.begin();"; + } else if (operand.isVariadicOfVariadic()) { + StringRef segmentAttr = + operand.constraint.getVariadicOfVariadicSegmentSizeAttr(); + if (isAdaptor) { + m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>", + operand.name); + m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, + segmentAttr, i); + continue; + } + + m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name); + m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr + << "Attr());"; } else if (operand.isVariadic()) { m = opClass.addMethodAndPrune(rangeType, operand.name); m->body() << " return getODSOperands(" << i << ");"; @@ -1033,6 +1069,7 @@ generateNamedOperandGetters( op, opClass, + /*isAdaptor=*/false, /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", @@ -1047,17 +1084,32 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange", + auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic() + ? "::mlir::MutableOperandRangeRange" + : "::mlir::MutableOperandRange", (operand.name + "Mutable").str()); auto &body = m->body(); body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" - << " return ::mlir::MutableOperandRange(getOperation(), " + << " auto mutableRange = ::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; if (attrSizedOperands) body << ", ::mlir::MutableOperandRange::OperandSegment(" << i << "u, *getOperation()->getAttrDictionary().getNamed(" "operand_segment_sizesAttrName()))"; body << ");\n"; + + // If this operand is a nested variadic, we split the range into a + // MutableOperandRangeRange that provides a range over all of the + // sub-ranges. + if (operand.isVariadicOfVariadic()) { + body << " return " + "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName()));\n"; + } else { + // Otherwise, we use the full range directly. + body << " return mutableRange;\n"; + } } } @@ -1211,7 +1263,9 @@ bool inferType) { llvm::SmallVector paramList; llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, paramKind, attrType); + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, paramKind, + attrType); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1219,8 +1273,9 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder( - body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, + /*isRawValueAttr=*/attrType == + AttrParamKind::UnwrappedValue); // Push all result types to the operation state @@ -1388,7 +1443,9 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { llvm::SmallVector paramList; llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::None); + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, + TypeParamKind::None); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1396,7 +1453,7 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder(body); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes); auto numResults = op.getNumResults(); if (numResults == 0) @@ -1588,6 +1645,7 @@ } void OpEmitter::buildParamList(SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { @@ -1626,10 +1684,6 @@ } // Add parameters for all arguments (operands and attributes). - - int numOperands = 0; - int numAttrs = 0; - int defaultValuedAttrStartIndex = op.getNumArgs(); if (attrParamKind == AttrParamKind::UnwrappedValue) { // Calculate the start index from which we can attach default values in the @@ -1655,54 +1709,68 @@ } } - for (int i = 0, e = op.getNumArgs(); i < e; ++i) { - auto argument = op.getArg(i); - if (argument.is()) { - const auto &operand = op.getOperand(numOperands); - StringRef type = - operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (operand.isOptional()) - properties = OpMethodParameter::PP_Optional; + /// Collect any inferred attributes. + for (const NamedTypeConstraint &operand : op.getOperands()) { + if (operand.isVariadicOfVariadic()) { + inferredAttributes.insert( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); + } + } - paramList.emplace_back(type, getArgumentName(op, numOperands), - properties); - ++numOperands; - } else { - const auto &namedAttr = op.getAttribute(numAttrs); - const auto &attr = namedAttr.attr; + for (unsigned i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { + Argument arg = op.getArg(i); + if (const auto *operand = arg.dyn_cast()) { + StringRef type; + if (operand->isVariadicOfVariadic()) + type = "::llvm::ArrayRef<::mlir::ValueRange>"; + else if (operand->isVariadic()) + type = "::mlir::ValueRange"; + else + type = "::mlir::Value"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (attr.isOptional()) + if (operand->isOptional()) properties = OpMethodParameter::PP_Optional; + paramList.emplace_back(type, getArgumentName(op, numOperands++), + properties); + continue; + } + const NamedAttribute &namedAttr = *arg.get(); + const Attribute &attr = namedAttr.attr; - StringRef type; - switch (attrParamKind) { - case AttrParamKind::WrappedAttr: + // inferred attributes don't need to be added to the param list. + if (inferredAttributes.contains(namedAttr.name)) + continue; + + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (attr.isOptional()) + properties = OpMethodParameter::PP_Optional; + + StringRef type; + switch (attrParamKind) { + case AttrParamKind::WrappedAttr: + type = attr.getStorageType(); + break; + case AttrParamKind::UnwrappedValue: + if (canUseUnwrappedRawValue(attr)) + type = attr.getReturnType(); + else type = attr.getStorageType(); - break; - case AttrParamKind::UnwrappedValue: - if (canUseUnwrappedRawValue(attr)) - type = attr.getReturnType(); - else - type = attr.getStorageType(); - break; - } + break; + } - std::string defaultValue; - // Attach default value if requested and possible. - if (attrParamKind == AttrParamKind::UnwrappedValue && - i >= defaultValuedAttrStartIndex) { - bool isString = attr.getReturnType() == "::llvm::StringRef"; - if (isString) - defaultValue.append("\""); - defaultValue += attr.getDefaultValue(); - if (isString) - defaultValue.append("\""); - } - paramList.emplace_back(type, namedAttr.name, defaultValue, properties); - ++numAttrs; + // Attach default value if requested and possible. + std::string defaultValue; + if (attrParamKind == AttrParamKind::UnwrappedValue && + i >= defaultValuedAttrStartIndex) { + bool isString = attr.getReturnType() == "::llvm::StringRef"; + if (isString) + defaultValue.append("\""); + defaultValue += attr.getDefaultValue(); + if (isString) + defaultValue.append("\""); } + paramList.emplace_back(type, namedAttr.name, defaultValue, properties); } /// Insert parameters for each successor. @@ -1719,12 +1787,31 @@ llvm::formatv("{0}Count", region.name).str()); } -void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr) { +void OpEmitter::genCodeForAddingArgAndRegionForBuilder( + OpMethodBody &body, llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr) { // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { std::string argName = getArgumentName(op, i); - if (op.getOperand(i).isOptional()) + NamedTypeConstraint &operand = op.getOperand(i); + if (operand.constraint.isVariadicOfVariadic()) { + body << " for (::mlir::ValueRange range : " << argName << ")\n " + << builderOpState << ".addOperands(range);\n"; + + // Add the segment attribute. + body << " {\n" + << " SmallVector rangeSegments;\n" + << " for (::mlir::ValueRange range : " << argName << ")\n" + << " rangeSegments.push_back(range.size());\n" + << " " << builderOpState << ".addAttribute(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName(" << builderOpState << ".name), " << odsBuilder + << ".getI32TensorAttr(rangeSegments));" + << " }\n"; + continue; + } + + if (operand.isOptional()) body << " if (" << argName << ")\n "; body << " " << builderOpState << ".addOperands(" << argName << ");\n"; } @@ -1736,12 +1823,24 @@ << ".name), " << "odsBuilder.getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isOptional()) - body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; - else if (op.getOperand(i).isVariadic()) - body << "static_cast(" << getArgumentName(op, i) << ".size())"; - else + const NamedTypeConstraint &operand = op.getOperand(i); + if (!operand.isVariableLength()) { body << "1"; + return; + } + + std::string operandName = getArgumentName(op, i); + if (operand.isOptional()) { + body << "(" << operandName << " ? 1 : 0)"; + } else if (operand.isVariadicOfVariadic()) { + body << llvm::formatv( + "static_cast(std::accumulate({0}.begin(), {0}.end(), 0, " + "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " + "range.size(); }))", + operandName); + } else { + body << "static_cast(" << getArgumentName(op, i) << ".size())"; + } }); body << "}));\n"; } @@ -1749,38 +1848,38 @@ // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; - if (!attr.isDerivedAttr()) { - bool emitNotNullCheck = attr.isOptional(); - if (emitNotNullCheck) - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; - - if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { - // If this is a raw value, then we need to wrap it in an Attribute - // instance. - FmtContext fctx; - fctx.withBuilder("odsBuilder"); - - std::string builderTemplate = - std::string(attr.getConstBuilderTemplate()); - - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - if (StringRef(builderTemplate).contains("\"$0\"")) - builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); - - std::string value = - std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, namedAttr.name, value); - } else { - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", - builderOpState, namedAttr.name); - } - if (emitNotNullCheck) - body << " }\n"; + if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) + continue; + + bool emitNotNullCheck = attr.isOptional(); + if (emitNotNullCheck) + body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // If this is a raw value, then we need to wrap it in an Attribute + // instance. + FmtContext fctx; + fctx.withBuilder("odsBuilder"); + + std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); + + // For StringAttr, its constant builder call will wrap the input in + // quotes, which is correct for normal string literals, but incorrect + // here given we use function arguments. So we need to strip the + // wrapping quotes. + if (StringRef(builderTemplate).contains("\"$0\"")) + builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + + std::string value = + std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", + builderOpState, namedAttr.name, value); + } else { + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", + builderOpState, namedAttr.name); } + if (emitNotNullCheck) + body << " }\n"; } // Create the correct number of regions. @@ -2430,7 +2529,8 @@ } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); - generateNamedOperandGetters(op, adaptor, sizeAttrInit, + generateNamedOperandGetters(op, adaptor, + /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()", diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -497,6 +497,7 @@ /// The set of attributes explicitly used within the format. SmallVector usedAttributes; + llvm::StringSet<> inferredAttributes; }; } // end anonymous namespace @@ -616,10 +617,41 @@ if (parser.parseOperand({0}RawOperands[0])) return ::mlir::failure(); )"; +/// The code snippet used to generate a parser call for a VariadicOfVariadic +/// operand. +/// +/// {0}: The name of the operand. +/// {1}: The name of segment size attribute. +const char *const variadicOfVariadicOperandParserCode = R"( + { + {0}OperandsLoc = parser.getCurrentLocation(); + ::llvm::SmallVector operandSizes; + int32_t curSize = 0; + do { + if (parser.parseOptionalLParen()) + break; + if (parser.parseOperandList({0}Operands) || parser.parseRParen()) + return ::mlir::failure(); + operandSizes.push_back({0}Operands.size() - curSize); + curSize = {0}Operands.size(); + } while (succeeded(parser.parseOptionalComma())); + result.addAttribute("{1}", + parser.getBuilder().getI32TensorAttr(operandSizes)); + } +)"; /// The code snippet used to generate a parser call for a type list. /// /// {0}: The name for the type list. +const char *const variadicOfVariadicTypeParserCode = R"( + do { + if (parser.parseOptionalLParen()) + break; + if (parser.parseOptionalRParen() && + (parser.parseTypeList({0}Types) || parser.parseRParen())) + return ::mlir::failure(); + } while (succeeded(parser.parseOptionalComma())); +)"; const char *const variadicTypeParserCode = R"( if (parser.parseTypeList({0}Types)) return ::mlir::failure(); @@ -758,6 +790,9 @@ namespace { /// The type of length for a given parse argument. enum class ArgumentLengthKind { + /// The argument is a variadic of a variadic, and may contain 0->N range + /// elements. + VariadicOfVariadic, /// The argument is variadic, and may contain 0->N elements. Variadic, /// The argument is optional, and may contain 0 or 1 elements. @@ -772,6 +807,8 @@ getArgumentLengthKind(const NamedTypeConstraint *var) { if (var->isOptional()) return ArgumentLengthKind::Optional; + if (var->isVariadicOfVariadic()) + return ArgumentLengthKind::VariadicOfVariadic; if (var->isVariadic()) return ArgumentLengthKind::Variadic; return ArgumentLengthKind::Single; @@ -924,7 +961,9 @@ } else if (auto *operand = dyn_cast(¶m)) { StringRef name = operand->getVar()->name; ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv("{0}OperandGroups", name); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv("{0}Operands", name); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv("{0}Operand", name); @@ -951,7 +990,9 @@ } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv("{0}TypeGroups", listName); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv("{0}Types", listName); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv("{0}Type", listName); @@ -972,19 +1013,32 @@ // * Set the location of operand variables. for (Element ¶m : dir->getArguments()) { if (auto *operand = dyn_cast(¶m)) { - body << " " << operand->getVar()->name + auto *var = operand->getVar(); + body << " " << var->name << "OperandsLoc = parser.getCurrentLocation();\n"; - if (operand->getVar()->isOptional()) { + if (var->isOptional()) { body << llvm::formatv( " llvm::Optional<::mlir::OpAsmParser::OperandType> " "{0}Operand;\n", - operand->getVar()->name); + var->name); + } else if (var->isVariadicOfVariadic()) { + body << llvm::formatv(" " + "llvm::SmallVector> " + "{0}OperandGroups;\n", + var->name); } } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Optional) + if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); + } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { + body << llvm::formatv( + " llvm::SmallVector> " + "{0}TypeGroups;\n", + listName); + } } else if (auto *dir = dyn_cast(¶m)) { Element *input = dir->getOperand(); if (auto *operand = dyn_cast(input)) { @@ -1028,11 +1082,21 @@ var->name); } else if (auto *operand = dyn_cast(¶m)) { const NamedTypeConstraint *var = operand->getVar(); - if (!var->isOptional()) - continue; - body << llvm::formatv(" if ({0}Operand.hasValue())\n" - " {0}Operands.push_back(*{0}Operand);\n", - var->name); + if (var->isOptional()) { + body << llvm::formatv(" if ({0}Operand.hasValue())\n" + " {0}Operands.push_back(*{0}Operand);\n", + var->name); + } else if (var->isVariadicOfVariadic()) { + body << llvm::formatv( + " llvm::SmallVector {0}OperandGroupSizes;\n" + " for (const auto &subRange : {0}OperandGroups) {{\n" + " {0}Operands.append(subRange.begin(), subRange.end());\n" + " {0}OperandGroupSizes.push_back(subRange.size());\n" + " }\n" + " result.addAttribute(\"{1}\", " + "parser.getBuilder().getI32TensorAttr({0}OperandGroupSizes));\n", + var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr()); + } } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1040,6 +1104,11 @@ body << llvm::formatv(" if ({0}Type)\n" " {0}Types.push_back({0}Type);\n", listName); + } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { + body << llvm::formatv( + " for (const auto &subRange : {0}TypeGroups)\n" + " {0}Types.append(subRange.begin(), subRange.end());\n", + listName); } } } @@ -1229,7 +1298,11 @@ } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv( + variadicOfVariadicOperandParserCode, name, + operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr()); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv(variadicOperandParserCode, name); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv(optionalOperandParserCode, name); @@ -1281,7 +1354,9 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv(variadicTypeParserCode, listName); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv(optionalTypeParserCode, listName); @@ -1575,6 +1650,10 @@ if (!fmt.allResultTypes && op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) body << "\"result_segment_sizes\", "; + if (!fmt.inferredAttributes.empty()) { + for (const auto &attr : fmt.inferredAttributes) + body << "\"" << attr.getKey() << "\", "; + } llvm::interleaveComma( fmt.usedAttributes, body, [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); @@ -1693,6 +1772,8 @@ return body << "getOperation()->getResultTypes()"; auto *operand = dyn_cast(arg); auto *var = operand ? operand->getVar() : cast(arg)->getVar(); + if (var->isVariadicOfVariadic()) + return body << llvm::formatv("{0}().join().getTypes()", var->name); if (var->isVariadic()) return body << var->name << "().getTypes()"; if (var->isOptional()) @@ -1896,7 +1977,12 @@ else body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { - if (operand->getVar()->isOptional()) { + if (operand->getVar()->isVariadicOfVariadic()) { + body << " ::llvm::interleaveComma(" << operand->getVar()->name + << "(), p, [&](const auto &operands) { p << \"(\" << operands << " + "\")\"; });\n"; + + } else if (operand->getVar()->isOptional()) { body << " if (::mlir::Value value = " << operand->getVar()->name << "())\n" << " p << value;\n"; @@ -1926,6 +2012,15 @@ } else if (isa(element)) { body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n"; } else if (auto *dir = dyn_cast(element)) { + if (auto *operand = dyn_cast(dir->getOperand())) { + if (operand->getVar()->isVariadicOfVariadic()) { + body << llvm::formatv(" ::llvm::interleaveComma({0}().getTypes(), p, " + "[&](::mlir::TypeRange types) {{ p << \"(\" << " + "types << \")\"; });\n", + operand->getVar()->name); + return; + } + } body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { @@ -2449,6 +2544,16 @@ while (!iteratorStack.empty()) if (failed(verifyAttributes(loc, iteratorStack))) return ::mlir::failure(); + + // Check for VariadicOfVariadic variables. The segment attribute of those + // variables will be infered. + for (const NamedTypeConstraint *var : seenOperands) { + if (var->constraint.isVariadicOfVariadic()) { + fmt.inferredAttributes.insert( + var->constraint.getVariadicOfVariadicSegmentSizeAttr()); + } + } + return ::mlir::success(); } /// Verify the attribute elements at the back of the given stack of iterators.