diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -229,6 +229,17 @@ indicate that all variable length operands have the same number of dynamic values. +#### VariadicOfVariadic operands + +To declare a variadic operand that has a variadic number of sub-ranges, wrap the +`TypeConstraint` for the operand with `VariadicOfVariadic<..., +"<segment-attribute-name>">`. + +The second field of the `VariadicOfVariadic` is the name of an `I32ElementsAttr` +argument that contains the sizes of the variadic sub-ranges. This attribute will +be used when determining the size of sub-ranges, or when updating the size of +sub-ranges. + #### Optional operands To declare an optional operand, wrap the `TypeConstraint` for the operand with @@ -717,6 +728,8 @@ - Single: `OpAsmParser::OperandType &` - Optional: `Optional<OpAsmParser::OperandType> &` - Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &` + - VariadicOfVariadic: + `SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &` * Ref Directives - A reference directive is passed to the parser using the same mapping as the input operand. For example, a single region would be passed as a @@ -731,6 +744,7 @@ - Single: `Type &` - Optional: `Type &` - Variadic: `SmallVectorImpl<Type> &` + - VariadicOfVariadic: `SmallVectorImpl<SmallVector<Type>> &` * `attr-dict` Directive: `NamedAttrList &` When a variable is optional, the value should only be specified if the variable @@ -749,6 +763,7 @@ - Single: `Value` - Optional: `Value` - Variadic: `OperandRange` + - VariadicOfVariadic: `OperandRangeRange` * Ref Directives - A reference directive is passed to the printer using the same mapping as the input operand. For example, a single region would be passed as a @@ -763,6 +778,7 @@ - Single: `Type` - Optional: `Type` - Variadic: `TypeRange` + - VariadicOfVariadic: `TypeRangeRange` * `attr-dict` Directive: `DictionaryAttr` When a variable is optional, the provided value may be null.
@@ -923,7 +939,7 @@ When this boolean field is set to `true`, it indicates that the op implements a `canonicalize` method for simple "matchAndRewrite" style canonicalization -patterns. If `hasCanonicalizer` is 0, then an implementation of +patterns. If `hasCanonicalizer` is 0, then an implementation of `::getCanonicalizationPatterns()` is implemented to call this function. ### `hasFolder` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -701,23 +701,25 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect]> { - let arguments = (ins I32:$value, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets, - OptionalAttr:$branch_weights); + let arguments = (ins + I32:$value, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + ElementsAttr:$case_operand_segments, + OptionalAttr:$branch_weights + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ $value `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? `[` `\n` custom($case_values, $caseDestinations, - $caseOperands, type($caseOperands), - $case_operand_offsets) `]` + $caseOperands, type($caseOperands)) `]` attr-dict }]; @@ -734,11 +736,15 @@ let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperands()[index]; + } /// Return a mutable range of operands for the case destination block at the /// given index. 
- MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsMutable()[index]; + } }]; } diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1812,14 +1812,17 @@ ``` }]; - let arguments = (ins AnyInteger:$flag, - Variadic:$defaultOperands, - Variadic:$caseOperands, - OptionalAttr:$case_values, - OptionalAttr:$case_operand_offsets); + let arguments = (ins + AnyInteger:$flag, + Variadic:$defaultOperands, + VariadicOfVariadic:$caseOperands, + OptionalAttr:$case_values, + I32ElementsAttr:$case_operand_segments + ); let successors = (successor - AnySuccessor:$defaultDestination, - VariadicSuccessor:$caseDestinations); + AnySuccessor:$defaultDestination, + VariadicSuccessor:$caseDestinations + ); let builders = [ OpBuilder<(ins "Value":$flag, "Block *":$defaultDestination, @@ -1849,19 +1852,22 @@ $case_values, $caseDestinations, $caseOperands, - type($caseOperands), - $case_operand_offsets) + type($caseOperands)) `]` attr-dict }]; let extraClassDeclaration = [{ /// Return the operands for the case destination block at the given index. - OperandRange getCaseOperands(unsigned index); + OperandRange getCaseOperands(unsigned index) { + return caseOperands()[index]; + } /// Return a mutable range of operands for the case destination block at the /// given index. - MutableOperandRange getCaseOperandsMutable(unsigned index); + MutableOperandRange getCaseOperandsMutable(unsigned index) { + return caseOperandsMutable()[index]; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -324,6 +324,16 @@ Type baseType = type; } +// A nested variadic type constraint. 
It expands to zero or more variadic ranges +// of the base type. This class is used for supporting variadic operands; it is +// not currently supported for results. `variadicSegmentAttrName` should correspond to the name of an +// I32ElementsAttr argument that provides the sizes of the inner variadic +// operand groups. +class VariadicOfVariadic + : Variadic { + string segmentAttrName = variadicSegmentAttrName; +} + // An optional type constraint. It expands to either zero or one of the base // type. This class is used for supporting optional operands/results. class Optional : TypeConstraint { diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -267,6 +267,9 @@ LogicalResult verifyOneSuccessor(Operation *op); LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); +LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, + StringRef valueGroupName, + size_t expectedCount); LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyNoRegionArguments(Operation *op); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -36,12 +36,14 @@ class Dialect; class DictionaryAttr; class ElementsAttr; +class MutableOperandRangeRange; class Operation; struct OperationState; class OpAsmParser; class OpAsmParserResult; class OpAsmPrinter; class OperandRange; +class OperandRangeRange; class OpFoldResult; class ParseResult; class Pattern; @@ -727,6 +729,10 @@ /// must not be empty.
unsigned getBeginOperandIndex() const; + /// Split this range into a set of contiguous subranges using the given + /// elements attribute, which contains the sizes of the sub ranges. + OperandRangeRange split(ElementsAttr segmentSizes) const; + private: /// See `llvm::detail::indexed_accessor_range_base` for details. static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { @@ -741,6 +747,42 @@ friend RangeBaseT; }; +//===----------------------------------------------------------------------===// +// OperandRangeRange + +/// This class represents a contiguous range of operand ranges, e.g. from a +/// VariadicOfVariadic operand group. +class OperandRangeRange final + : public llvm::indexed_accessor_range< + OperandRangeRange, std::pair, OperandRange, + OperandRange, OperandRange> { + using OwnerT = std::pair; + using RangeBaseT = + llvm::indexed_accessor_range; + +public: + using RangeBaseT::RangeBaseT; + + /// Returns the range of types of the values within this range. + TypeRangeRange getTypes() const { return TypeRangeRange(*this); } + auto getType() const { return getTypes(); } + + /// Construct a range given a parent set of operands, and an I32 elements + /// attribute containing the sizes of the sub ranges. + OperandRangeRange(OperandRange operands, Attribute operandSegments); + + /// Flatten all of the sub ranges into a single contiguous operand range. + OperandRange join() const; + +private: + /// See `llvm::indexed_accessor_range` for details. + static OperandRange dereference(const OwnerT &object, ptrdiff_t index); + + /// Allow access to `dereference_iterator`. + friend RangeBaseT; +}; + //===----------------------------------------------------------------------===// // MutableOperandRange @@ -761,8 +803,9 @@ MutableOperandRange(Operation *owner); /// Slice this range into a sub range, with the additional operand segment. 
- MutableOperandRange slice(unsigned subStart, unsigned subLen, - Optional segment = llvm::None); + MutableOperandRange + slice(unsigned subStart, unsigned subLen, + Optional segment = llvm::None) const; /// Append the given values to the range. void append(ValueRange values); @@ -782,12 +825,19 @@ /// Returns the current size of the range. unsigned size() const { return length; } + /// Returns true if the current range is empty. + bool empty() const { return size() == 0; } + /// Allow implicit conversion to an OperandRange. operator OperandRange() const; /// Returns the owning operation. Operation *getOwner() const { return owner; } + /// Split this range into a set of contiguous subranges using the given + /// named elements attribute, whose value contains the sizes of the sub ranges. + MutableOperandRangeRange split(NamedAttribute segmentSizes) const; + private: /// Update the length of this range to the one provided. void updateLength(unsigned newLength); @@ -801,7 +851,46 @@ /// Optional set of operand segments that should be updated when mutating the /// length of this range. - SmallVector, 1> operandSegments; + SmallVector operandSegments; +}; + +//===----------------------------------------------------------------------===// +// MutableOperandRangeRange + +/// This class represents a contiguous range of mutable operand ranges, e.g. +/// from a VariadicOfVariadic operand group. +class MutableOperandRangeRange final + : public llvm::indexed_accessor_range< + MutableOperandRangeRange, + std::pair, MutableOperandRange, + MutableOperandRange, MutableOperandRange> { + using OwnerT = std::pair; + using RangeBaseT = + llvm::indexed_accessor_range; + +public: + using RangeBaseT::RangeBaseT; + + /// Construct a range given a parent set of operands, and a named I32 + /// elements attribute containing the sizes of the sub ranges.
+ MutableOperandRangeRange(const MutableOperandRange &operands, + NamedAttribute operandSegmentAttr); + + /// Flatten all of the sub ranges into a single contiguous mutable operand + /// range. + MutableOperandRange join() const; + + /// Allow implicit conversion to an OperandRangeRange. + operator OperandRangeRange() const; + +private: + /// See `llvm::indexed_accessor_range` for details. + static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index); + + /// Allow access to `dereference_iterator`. + friend RangeBaseT; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/TypeRange.h b/mlir/include/mlir/IR/TypeRange.h --- a/mlir/include/mlir/IR/TypeRange.h +++ b/mlir/include/mlir/IR/TypeRange.h @@ -16,6 +16,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "llvm/ADT/PointerUnion.h" +#include "llvm/ADT/Sequence.h" namespace mlir { class OperandRange; @@ -88,6 +89,35 @@ return os; } +//===----------------------------------------------------------------------===// +// TypeRangeRange + +using TypeRangeRangeIterator = + llvm::mapped_iterator::iterator, + std::function>; + +/// This class provides an abstraction for a range of TypeRange. This is useful +/// when accessing the types of a range of ranges, such as when using +/// OperandRangeRange. 
+class TypeRangeRange : public llvm::iterator_range { +public: + template + TypeRangeRange(const RangeT &range) + : TypeRangeRange(llvm::seq(0, range.size()), range) {} + +private: + template + TypeRangeRange(llvm::iota_range sizeRange, const RangeT &range) + : llvm::iterator_range( + {sizeRange.begin(), getRangeFn(range)}, + {sizeRange.end(), nullptr}) {} + + template + static std::function getRangeFn(const RangeT &range) { + return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); }; + } +}; + //===----------------------------------------------------------------------===// // ValueTypeRange diff --git a/mlir/include/mlir/TableGen/Argument.h b/mlir/include/mlir/TableGen/Argument.h --- a/mlir/include/mlir/TableGen/Argument.h +++ b/mlir/include/mlir/TableGen/Argument.h @@ -48,6 +48,8 @@ bool isOptional() const; // Returns true if this operand/result is variadic. bool isVariadic() const; + // Returns true if this operand/result is a variadic of a variadic constraint. + bool isVariadicOfVariadic() const; // Returns true if this is a variable length type constraint. This is either // variadic or optional. bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/include/mlir/TableGen/Type.h b/mlir/include/mlir/TableGen/Type.h --- a/mlir/include/mlir/TableGen/Type.h +++ b/mlir/include/mlir/TableGen/Type.h @@ -40,6 +40,13 @@ // Returns true if this is a variadic type constraint. bool isVariadic() const; + // Returns true if this is a nested variadic type constraint. + bool isVariadicOfVariadic() const; + + // Return the segment size attribute used if this is a variadic of variadic + // constraint. Asserts isVariadicOfVariadic() is true. + StringRef getVariadicOfVariadicSegmentSizeAttr() const; + // Returns true if this is a variable length type constraint. This is either // variadic or optional. 
bool isVariableLength() const { return isOptional() || isVariadic(); } diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -520,7 +520,7 @@ /*defaultOperands=*/ValueRange(), /*caseValues=*/caseValues, /*caseDestinations=*/caseDest, - /*caseOperands=*/ArrayRef(), + /*caseOperands=*/ArrayRef({ValueRange(), ValueRange()}), /*branchWeights=*/ArrayRef()); return success(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/SourceMgr.h" #include +#include using namespace mlir; using namespace mlir::LLVM; @@ -235,41 +236,27 @@ ArrayRef caseValues, BlockRange caseDestinations, ArrayRef caseOperands, ArrayRef branchWeights) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } ElementsAttr caseValuesAttr; if (!caseValues.empty()) caseValuesAttr = builder.getI32VectorAttr(caseValues); - ElementsAttr caseOperandOffsetsAttr; - if (!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); ElementsAttr weightsAttr; if (!branchWeights.empty()) weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights)); - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr, + weightsAttr, defaultDestination, caseDestinations); } /// ::= integer `:` bb-id (`(` 
ssa-use-and-type-list `)`)? /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )? -static ParseResult -parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues, - SmallVectorImpl &caseDestinations, - SmallVectorImpl &caseOperands, - SmallVectorImpl &caseOperandTypes, - ElementsAttr &caseOperandOffsets) { +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, ElementsAttr &caseValues, + SmallVectorImpl &caseDestinations, + SmallVectorImpl> &caseOperands, + SmallVectorImpl> &caseOperandTypes) { SmallVector values; - SmallVector offsets; - int32_t value, offset = 0; + int32_t value = 0; do { OptionalParseResult integerParseResult = parser.parseOptionalInteger(value); if (values.empty() && !integerParseResult.hasValue()) @@ -281,32 +268,28 @@ Block *destination; SmallVector operands; + SmallVector operandTypes; if (parser.parseColon() || parser.parseSuccessor(destination)) return failure(); if (!parser.parseOptionalLParen()) { if (parser.parseRegionArgumentList(operands) || - parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen()) + parser.parseColonTypeList(operandTypes) || parser.parseRParen()) return failure(); } caseDestinations.push_back(destination); - caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); } while (!parser.parseOptionalComma()); - Builder &builder = parser.getBuilder(); - caseValues = builder.getI32VectorAttr(values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); - + caseValues = parser.getBuilder().getI32VectorAttr(values); return success(); } static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, ElementsAttr caseValues, SuccessorRange caseDestinations, - OperandRange caseOperands, - TypeRange caseOperandTypes, - ElementsAttr caseOperandOffsets) { + OperandRangeRange caseOperands, + TypeRangeRange caseOperandTypes) { if (!caseValues) return; @@ 
-317,7 +300,7 @@ p << " "; p << std::get<0>(i).getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++)); + p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); }, [&] { p << ','; @@ -341,28 +324,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -28,6 +28,7 @@ #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include #include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc" @@ -2130,21 +2131,8 @@ DenseIntElementsAttr caseValues, BlockRange caseDestinations, ArrayRef caseOperands) { - SmallVector flattenedCaseOperands; - SmallVector caseOperandOffsets; - int32_t offset = 0; - for (ValueRange operands : caseOperands) { - flattenedCaseOperands.append(operands.begin(), operands.end()); - caseOperandOffsets.push_back(offset); - offset += operands.size(); - } - DenseIntElementsAttr caseOperandOffsetsAttr; - if 
(!caseOperandOffsets.empty()) - caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets); - - build(builder, result, value, defaultOperands, flattenedCaseOperands, - caseValues, caseOperandOffsetsAttr, defaultDestination, - caseDestinations); + build(builder, result, value, defaultOperands, caseOperands, caseValues, + defaultDestination, caseDestinations); } void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value, @@ -2163,16 +2151,14 @@ /// ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)? /// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )* -static ParseResult -parseSwitchOpCases(OpAsmParser &parser, Type &flagType, - Block *&defaultDestination, - SmallVectorImpl &defaultOperands, - SmallVectorImpl &defaultOperandTypes, - DenseIntElementsAttr &caseValues, - SmallVectorImpl &caseDestinations, - SmallVectorImpl &caseOperands, - SmallVectorImpl &caseOperandTypes, - DenseIntElementsAttr &caseOperandOffsets) { +static ParseResult parseSwitchOpCases( + OpAsmParser &parser, Type &flagType, Block *&defaultDestination, + SmallVectorImpl &defaultOperands, + SmallVectorImpl &defaultOperandTypes, + DenseIntElementsAttr &caseValues, + SmallVectorImpl &caseDestinations, + SmallVectorImpl> &caseOperands, + SmallVectorImpl> &caseOperandTypes) { if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) || failed(parser.parseSuccessor(defaultDestination))) return failure(); @@ -2184,9 +2170,7 @@ } SmallVector values; - SmallVector offsets; unsigned bitWidth = flagType.getIntOrFloatBitWidth(); - int64_t offset = 0; while (succeeded(parser.parseOptionalComma())) { int64_t value = 0; if (failed(parser.parseInteger(value))) @@ -2195,30 +2179,26 @@ Block *destination; SmallVector operands; + SmallVector operandTypes; if (failed(parser.parseColon()) || failed(parser.parseSuccessor(destination))) return failure(); if (succeeded(parser.parseOptionalLParen())) { if (failed(parser.parseRegionArgumentList(operands)) || - 
failed(parser.parseColonTypeList(caseOperandTypes)) || + failed(parser.parseColonTypeList(operandTypes)) || failed(parser.parseRParen())) return failure(); } caseDestinations.push_back(destination); - caseOperands.append(operands.begin(), operands.end()); - offsets.push_back(offset); - offset += operands.size(); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); } - if (values.empty()) - return success(); - - Builder &builder = parser.getBuilder(); - ShapedType caseValueType = - VectorType::get(static_cast(values.size()), flagType); - caseValues = DenseIntElementsAttr::get(caseValueType, values); - caseOperandOffsets = builder.getI32VectorAttr(offsets); - + if (!values.empty()) { + ShapedType caseValueType = + VectorType::get(static_cast(values.size()), flagType); + caseValues = DenseIntElementsAttr::get(caseValueType, values); + } return success(); } @@ -2226,8 +2206,7 @@ OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination, OperandRange defaultOperands, TypeRange defaultOperandTypes, DenseIntElementsAttr caseValues, SuccessorRange caseDestinations, - OperandRange caseOperands, TypeRange caseOperandTypes, - ElementsAttr caseOperandOffsets) { + OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) { p << " default: "; p.printSuccessorAndUseList(defaultDestination, defaultOperands); @@ -2240,7 +2219,7 @@ p << " "; p << caseValues.getValue(i).getLimitedValue(); p << ": "; - p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i)); + p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]); } p.printNewline(); } @@ -2268,28 +2247,6 @@ return success(); } -OperandRange SwitchOp::getCaseOperands(unsigned index) { - return getCaseOperandsMutable(index); -} - -MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) { - MutableOperandRange caseOperands = caseOperandsMutable(); - if (!case_operand_offsets()) { - assert(caseOperands.size() == 0 && - "non-empty case operands 
must have offsets"); - return caseOperands; - } - - ElementsAttr offsets = case_operand_offsets().getValue(); - assert(index < offsets.size() && "invalid case operand offset index"); - - int64_t begin = offsets.getValue(index).cast().getInt(); - int64_t end = index + 1 == offsets.size() - ? caseOperands.size() - : offsets.getValue(index + 1).cast().getInt(); - return caseOperandsMutable().slice(begin, end - begin); -} - Optional SwitchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -996,16 +996,19 @@ return success(); } -static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, - bool isOperand) { +LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op, + StringRef attrName, + StringRef valueGroupName, + size_t expectedCount) { auto sizeAttr = op->getAttrOfType(attrName); if (!sizeAttr) - return op->emitOpError("requires 1D vector attribute '") << attrName << "'"; + return op->emitOpError("requires 1D i32 elements attribute '") + << attrName << "'"; - auto sizeAttrType = sizeAttr.getType().dyn_cast(); - if (!sizeAttrType || sizeAttrType.getRank() != 1 || + auto sizeAttrType = sizeAttr.getType(); + if (sizeAttrType.getRank() != 1 || !sizeAttrType.getElementType().isInteger(32)) - return op->emitOpError("requires 1D vector of i32 attribute '") + return op->emitOpError("requires 1D i32 elements attribute '") << attrName << "'"; if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) { @@ -1018,25 +1021,22 @@ sizeAttr.begin(), sizeAttr.end(), 0, [](unsigned all, APInt one) { return all + one.getZExtValue(); }); - if (isOperand && totalCount != op->getNumOperands()) - return op->emitOpError("operand count (") - << op->getNumOperands() << ") does not match with the total size (" - << totalCount << ") specified in attribute '" << attrName 
<< "'"; - else if (!isOperand && totalCount != op->getNumResults()) - return op->emitOpError("result count (") - << op->getNumResults() << ") does not match with the total size (" - << totalCount << ") specified in attribute '" << attrName << "'"; + if (totalCount != expectedCount) + return op->emitOpError() + << valueGroupName << " count (" << expectedCount + << ") does not match with the total size (" << totalCount + << ") specified in attribute '" << attrName << "'"; return success(); } LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op, StringRef attrName) { - return verifyValueSizeAttr(op, attrName, /*isOperand=*/true); + return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands()); } LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op, StringRef attrName) { - return verifyValueSizeAttr(op, attrName, /*isOperand=*/false); + return verifyValueSizeAttr(op, attrName, "result", op->getNumResults()); } LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) { diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -12,9 +12,11 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" #include "llvm/ADT/BitVector.h" +#include using namespace mlir; @@ -394,13 +396,38 @@ OperandRange::OperandRange(Operation *op) : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} -/// Return the operand index of the first element of this range. The range -/// must not be empty. 
unsigned OperandRange::getBeginOperandIndex() const { assert(!empty() && "range must not be empty"); return base->getOperandNumber(); } +OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const { + return OperandRangeRange(*this, segmentSizes); +} + +//===----------------------------------------------------------------------===// +// OperandRangeRange + +OperandRangeRange::OperandRangeRange(OperandRange operands, + Attribute operandSegments) + : OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0, + operandSegments.cast().size()) {} + +OperandRange OperandRangeRange::join() const { + const OwnerT &owner = getBase(); + auto sizeData = owner.second.cast().getValues(); + return OperandRange(owner.first, + std::accumulate(sizeData.begin(), sizeData.end(), 0)); +} + +OperandRange OperandRangeRange::dereference(const OwnerT &object, + ptrdiff_t index) { + auto sizeData = object.second.cast().getValues(); + uint32_t startIndex = + std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + return OperandRange(object.first + startIndex, *(sizeData.begin() + index)); +} + //===----------------------------------------------------------------------===// // MutableOperandRange @@ -419,7 +446,7 @@ /// Slice this range into a sub range, with the additional operand segment. MutableOperandRange MutableOperandRange::slice(unsigned subStart, unsigned subLen, - Optional segment) { + Optional segment) const { assert((subStart + subLen) <= length && "invalid sub-range"); MutableOperandRange subSlice(owner, start + subStart, subLen, operandSegments); @@ -475,6 +502,11 @@ return owner->getOperands().slice(start, length); } +MutableOperandRangeRange +MutableOperandRange::split(NamedAttribute segmentSizes) const { + return MutableOperandRangeRange(*this, segmentSizes); +} + /// Update the length of this range to the one provided. 
void MutableOperandRange::updateLength(unsigned newLength) { int32_t diff = int32_t(newLength) - int32_t(length); @@ -490,6 +522,35 @@ } } +//===----------------------------------------------------------------------===// +// MutableOperandRangeRange + +MutableOperandRangeRange::MutableOperandRangeRange( + const MutableOperandRange &operands, NamedAttribute operandSegmentAttr) + : MutableOperandRangeRange( + OwnerT(operands, operandSegmentAttr), 0, + operandSegmentAttr.second.cast().size()) {} + +MutableOperandRange MutableOperandRangeRange::join() const { + return getBase().first; +} + +MutableOperandRangeRange::operator OperandRangeRange() const { + return OperandRangeRange(getBase().first, + getBase().second.second.cast()); +} + +MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object, + ptrdiff_t index) { + auto sizeData = + object.second.second.cast().getValues(); + uint32_t startIndex = + std::accumulate(sizeData.begin(), sizeData.begin() + index, 0); + return object.first.slice( + startIndex, *(sizeData.begin() + index), + MutableOperandRange::OperandSegment(index, object.second)); +} + //===----------------------------------------------------------------------===// // ValueRange diff --git a/mlir/lib/TableGen/Argument.cpp b/mlir/lib/TableGen/Argument.cpp --- a/mlir/lib/TableGen/Argument.cpp +++ b/mlir/lib/TableGen/Argument.cpp @@ -12,6 +12,10 @@ using namespace mlir; using namespace mlir::tblgen; +//===----------------------------------------------------------------------===// +// NamedTypeConstraint +//===----------------------------------------------------------------------===// + bool NamedTypeConstraint::hasPredicate() const { return !constraint.getPredicate().isNull(); } @@ -19,3 +23,7 @@ bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); } bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); } + +bool NamedTypeConstraint::isVariadicOfVariadic() const { + return 
constraint.isVariadicOfVariadic(); +} diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -458,6 +458,13 @@ results.push_back({name, TypeConstraint(resultDef)}); if (!name.empty()) argumentsAndResultsIndex[name] = resultIndex(i); + + // VariadicOfVariadic is currently only supported for operands, not results. + if (results.back().constraint.isVariadicOfVariadic()) { + PrintFatalError( + def.getLoc(), + "'VariadicOfVariadic' results are currently not supported"); + } } // Handle successors @@ -577,8 +584,7 @@ StringRef Operator::getAssemblyFormat() const { return TypeSwitch(def.getValueInit("assemblyFormat")) - .Case( - [&](auto *init) { return init->getValue(); }); + .Case([&](auto *init) { return init->getValue(); }); } void Operator::print(llvm::raw_ostream &os) const { diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -36,6 +36,15 @@ return def->isSubClassOf("Variadic"); } +bool TypeConstraint::isVariadicOfVariadic() const { + return def->isSubClassOf("VariadicOfVariadic"); +} + +StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const { + assert(isVariadicOfVariadic()); + return def->getValueAsString("segmentAttrName"); +} + // Returns the builder call for this constraint if this is a buildable type, // returns None otherwise.
Optional TypeConstraint::getBuilderCall() const { diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -375,28 +375,28 @@ // ----- func @failedMissingOperandSizeAttr(%arg: i32) { - // expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> () } // ----- func @failedOperandSizeAttrWrongType(%arg: i32) { - // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}} - "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : (i32, i32, i32, i32) -> () + // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}} + "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> () } // ----- func @failedOperandSizeAttrWrongRank(%arg: i32) { - // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : (i32, i32, i32, i32) -> () } // ----- func @failedOperandSizeAttrWrongElementType(%arg: i32) { - // expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}} "test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : (i32, i32, i32, i32) -> () } @@ -432,28 +432,28 @@ // ----- func @failedMissingResultSizeAttr() { - // expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}} 
%0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32) } // ----- func @failedResultSizeAttrWrongType() { - // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}} - %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : () -> (i32, i32, i32, i32) + // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}} + %0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32) } // ----- func @failedResultSizeAttrWrongRank() { - // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : () -> (i32, i32, i32, i32) } // ----- func @failedResultSizeAttrWrongElementType() { - // expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}} + // expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}} %0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : () -> (i32, i32, i32, i32) } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1661,6 +1661,14 @@ let arguments = (ins Variadic:$operand); let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; } +def FormatVariadicOfVariadicOperand + : TEST_Op<"format_variadic_of_variadic_operand"> { + let arguments = (ins + VariadicOfVariadic:$operand, + I32ElementsAttr:$operand_segments + ); + let assemblyFormat = [{ $operand `:` type($operand) attr-dict}]; +} def FormatMultipleVariadicOperands : TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir --- 
a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -151,6 +151,9 @@ // CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64 test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64 +// CHECK: test.format_variadic_of_variadic_operand (%[[I64]], %[[I64]]), (), (%[[I64]]) : (i64, i64), (), (i64) +test.format_variadic_of_variadic_operand (%i64, %i64), (), (%i64) : (i64, i64), (), (i64) + // CHECK: test.format_multiple_variadic_operands (%[[I64]], %[[I64]], %[[I64]]), (%[[I64]], %[[I32]] : i64, i32) test.format_multiple_variadic_operands (%i64, %i64, %i64), (%i64, %i32 : i64, i32) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -24,6 +24,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringSet.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" @@ -89,6 +90,23 @@ unsigned size = *(sizeAttrValues.begin() + index); return {start, size}; )"; +// The logic to calculate the actual value range for a declared operand +// of an op with variadic of variadic operands within the OpAdaptor. +// +// {0}: The name of the segment attribute. +// {1}: The index of the main operand. 
+const char *variadicOfVariadicAdaptorCalcCode = R"( + auto tblgenTmpOperands = getODSOperands({1}); + auto sizeAttrValues = {0}().getValues(); + auto sizeAttrIt = sizeAttrValues.begin(); + + ::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups; + for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{ + tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt)); + tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt); + } + return tblgenTmpOperandGroups; +)"; // The logic to build a range of either operand or result values. // @@ -256,16 +274,20 @@ // Builds the parameter list for build() method of this op. This method writes // to `paramList` the comma-separated parameter list and updates // `resultTypeNames` with the names for parameters for specifying result - // types. The given `typeParamKind` and `attrParamKind` controls how result - // types and attributes are placed in the parameter list. + // types. `inferredAttributes` is populated with any attributes that are + // elided from the build list. The given `typeParamKind` and `attrParamKind` + // control how result types and attributes are placed in the parameter list. void buildParamList(llvm::SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind = AttrParamKind::WrappedAttr); // Adds op arguments and regions into operation state for build() methods. - void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr = false); + void + genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, + llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr = false); // Generates canonicalizer declaration for the operation. void genCanonicalizerDecls(); @@ -783,7 +805,7 @@ // of ops, in particular for one-operand ops that may not have the // `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass, - StringRef sizeAttrInit, + bool isAdaptor, StringRef sizeAttrInit, StringRef rangeType, StringRef rangeBeginCall, StringRef rangeSizeCall, @@ -838,6 +860,20 @@ m->body() << " auto operands = getODSOperands(" << i << ");\n" << " return operands.empty() ? ::mlir::Value() : *operands.begin();"; + } else if (operand.isVariadicOfVariadic()) { + StringRef segmentAttr = + operand.constraint.getVariadicOfVariadicSegmentSizeAttr(); + if (isAdaptor) { + m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>", + operand.name); + m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode, + segmentAttr, i); + continue; + } + + m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name); + m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr + << "Attr());"; } else if (operand.isVariadic()) { m = opClass.addMethodAndPrune(rangeType, operand.name); m->body() << " return getODSOperands(" << i << ");"; @@ -860,6 +896,7 @@ generateNamedOperandGetters( op, opClass, + /*isAdaptor=*/false, /*sizeAttrInit=*/attrSizeInitCode, /*rangeType=*/"::mlir::Operation::operand_range", /*rangeBeginCall=*/"getOperation()->operand_begin()", @@ -874,17 +911,32 @@ const auto &operand = op.getOperand(i); if (operand.name.empty()) continue; - auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange", + auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic() + ? 
"::mlir::MutableOperandRangeRange" + : "::mlir::MutableOperandRange", (operand.name + "Mutable").str()); auto &body = m->body(); body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" - << " return ::mlir::MutableOperandRange(getOperation(), " + << " auto mutableRange = ::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; if (attrSizedOperands) body << ", ::mlir::MutableOperandRange::OperandSegment(" << i << "u, *getOperation()->getAttrDictionary().getNamed(" "operand_segment_sizesAttrName()))"; body << ");\n"; + + // If this operand is a nested variadic, we split the range into a + // MutableOperandRangeRange that provides a range over all of the + // sub-ranges. + if (operand.isVariadicOfVariadic()) { + body << " return " + "mutableRange.split(*(*this)->getAttrDictionary().getNamed(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName()));\n"; + } else { + // Otherwise, we use the full range directly. + body << " return mutableRange;\n"; + } } } @@ -1038,7 +1090,9 @@ bool inferType) { llvm::SmallVector paramList; llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, paramKind, attrType); + llvm::StringSet<> inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, paramKind, + attrType); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1046,8 +1100,9 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder( - body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes, + /*isRawValueAttr=*/attrType == + AttrParamKind::UnwrappedValue); // Push all result types to the operation state @@ -1215,7 +1270,9 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { llvm::SmallVector paramList; llvm::SmallVector resultNames; - buildParamList(paramList, resultNames, TypeParamKind::None); + llvm::StringSet<> 
inferredAttributes; + buildParamList(paramList, inferredAttributes, resultNames, + TypeParamKind::None); auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static, std::move(paramList)); @@ -1223,7 +1280,7 @@ if (!m) return; auto &body = m->body(); - genCodeForAddingArgAndRegionForBuilder(body); + genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes); auto numResults = op.getNumResults(); if (numResults == 0) @@ -1415,6 +1472,7 @@ } void OpEmitter::buildParamList(SmallVectorImpl ¶mList, + llvm::StringSet<> &inferredAttributes, SmallVectorImpl &resultTypeNames, TypeParamKind typeParamKind, AttrParamKind attrParamKind) { @@ -1453,10 +1511,6 @@ } // Add parameters for all arguments (operands and attributes). - - int numOperands = 0; - int numAttrs = 0; - int defaultValuedAttrStartIndex = op.getNumArgs(); if (attrParamKind == AttrParamKind::UnwrappedValue) { // Calculate the start index from which we can attach default values in the @@ -1482,54 +1536,68 @@ } } - for (int i = 0, e = op.getNumArgs(); i < e; ++i) { - auto argument = op.getArg(i); - if (argument.is()) { - const auto &operand = op.getOperand(numOperands); - StringRef type = - operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value"; - OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (operand.isOptional()) - properties = OpMethodParameter::PP_Optional; + /// Collect any inferred attributes. 
+ for (const NamedTypeConstraint &operand : op.getOperands()) { + if (operand.isVariadicOfVariadic()) { + inferredAttributes.insert( + operand.constraint.getVariadicOfVariadicSegmentSizeAttr()); + } + } - paramList.emplace_back(type, getArgumentName(op, numOperands), - properties); - ++numOperands; - } else { - const auto &namedAttr = op.getAttribute(numAttrs); - const auto &attr = namedAttr.attr; + for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) { + Argument arg = op.getArg(i); + if (const auto *operand = arg.dyn_cast()) { + StringRef type; + if (operand->isVariadicOfVariadic()) + type = "::llvm::ArrayRef<::mlir::ValueRange>"; + else if (operand->isVariadic()) + type = "::mlir::ValueRange"; + else + type = "::mlir::Value"; OpMethodParameter::Property properties = OpMethodParameter::PP_None; - if (attr.isOptional()) + if (operand->isOptional()) properties = OpMethodParameter::PP_Optional; + paramList.emplace_back(type, getArgumentName(op, numOperands++), + properties); + continue; + } + const NamedAttribute &namedAttr = *arg.get(); + const Attribute &attr = namedAttr.attr; - StringRef type; - switch (attrParamKind) { - case AttrParamKind::WrappedAttr: + // inferred attributes don't need to be added to the param list. + if (inferredAttributes.contains(namedAttr.name)) + continue; + + OpMethodParameter::Property properties = OpMethodParameter::PP_None; + if (attr.isOptional()) + properties = OpMethodParameter::PP_Optional; + + StringRef type; + switch (attrParamKind) { + case AttrParamKind::WrappedAttr: + type = attr.getStorageType(); + break; + case AttrParamKind::UnwrappedValue: + if (canUseUnwrappedRawValue(attr)) + type = attr.getReturnType(); + else type = attr.getStorageType(); - break; - case AttrParamKind::UnwrappedValue: - if (canUseUnwrappedRawValue(attr)) - type = attr.getReturnType(); - else - type = attr.getStorageType(); - break; - } + break; + } - std::string defaultValue; - // Attach default value if requested and possible. 
- if (attrParamKind == AttrParamKind::UnwrappedValue && - i >= defaultValuedAttrStartIndex) { - bool isString = attr.getReturnType() == "::llvm::StringRef"; - if (isString) - defaultValue.append("\""); - defaultValue += attr.getDefaultValue(); - if (isString) - defaultValue.append("\""); - } - paramList.emplace_back(type, namedAttr.name, defaultValue, properties); - ++numAttrs; + // Attach default value if requested and possible. + std::string defaultValue; + if (attrParamKind == AttrParamKind::UnwrappedValue && + i >= defaultValuedAttrStartIndex) { + bool isString = attr.getReturnType() == "::llvm::StringRef"; + if (isString) + defaultValue.append("\""); + defaultValue += attr.getDefaultValue(); + if (isString) + defaultValue.append("\""); } + paramList.emplace_back(type, namedAttr.name, defaultValue, properties); } /// Insert parameters for each successor. @@ -1546,12 +1614,31 @@ llvm::formatv("{0}Count", region.name).str()); } -void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body, - bool isRawValueAttr) { +void OpEmitter::genCodeForAddingArgAndRegionForBuilder( + OpMethodBody &body, llvm::StringSet<> &inferredAttributes, + bool isRawValueAttr) { // Push all operands to the result. for (int i = 0, e = op.getNumOperands(); i < e; ++i) { std::string argName = getArgumentName(op, i); - if (op.getOperand(i).isOptional()) + NamedTypeConstraint &operand = op.getOperand(i); + if (operand.constraint.isVariadicOfVariadic()) { + body << " for (::mlir::ValueRange range : " << argName << ")\n " + << builderOpState << ".addOperands(range);\n"; + + // Add the segment attribute. 
+ body << " {\n" + << " SmallVector rangeSegments;\n" + << " for (::mlir::ValueRange range : " << argName << ")\n" + << " rangeSegments.push_back(range.size());\n" + << " " << builderOpState << ".addAttribute(" + << operand.constraint.getVariadicOfVariadicSegmentSizeAttr() + << "AttrName(" << builderOpState << ".name), " << odsBuilder + << ".getI32TensorAttr(rangeSegments));" + << " }\n"; + continue; + } + + if (operand.isOptional()) body << " if (" << argName << ")\n "; body << " " << builderOpState << ".addOperands(" << argName << ");\n"; } @@ -1563,12 +1650,24 @@ << ".name), " << "odsBuilder.getI32VectorAttr({"; interleaveComma(llvm::seq(0, op.getNumOperands()), body, [&](int i) { - if (op.getOperand(i).isOptional()) - body << "(" << getArgumentName(op, i) << " ? 1 : 0)"; - else if (op.getOperand(i).isVariadic()) - body << "static_cast(" << getArgumentName(op, i) << ".size())"; - else + const NamedTypeConstraint &operand = op.getOperand(i); + if (!operand.isVariableLength()) { body << "1"; + return; + } + + std::string operandName = getArgumentName(op, i); + if (operand.isOptional()) { + body << "(" << operandName << " ? 1 : 0)"; + } else if (operand.isVariadicOfVariadic()) { + body << llvm::formatv( + "static_cast(std::accumulate({0}.begin(), {0}.end(), 0, " + "[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + " + "range.size(); }))", + operandName); + } else { + body << "static_cast(" << getArgumentName(op, i) << ".size())"; + } }); body << "}));\n"; } @@ -1576,38 +1675,38 @@ // Push all attributes to the result. for (const auto &namedAttr : op.getAttributes()) { auto &attr = namedAttr.attr; - if (!attr.isDerivedAttr()) { - bool emitNotNullCheck = attr.isOptional(); - if (emitNotNullCheck) - body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; - - if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { - // If this is a raw value, then we need to wrap it in an Attribute - // instance. 
- FmtContext fctx; - fctx.withBuilder("odsBuilder"); - - std::string builderTemplate = - std::string(attr.getConstBuilderTemplate()); - - // For StringAttr, its constant builder call will wrap the input in - // quotes, which is correct for normal string literals, but incorrect - // here given we use function arguments. So we need to strip the - // wrapping quotes. - if (StringRef(builderTemplate).contains("\"$0\"")) - builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); - - std::string value = - std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", - builderOpState, namedAttr.name, value); - } else { - body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", - builderOpState, namedAttr.name); - } - if (emitNotNullCheck) - body << " }\n"; + if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name)) + continue; + + bool emitNotNullCheck = attr.isOptional(); + if (emitNotNullCheck) + body << formatv(" if ({0}) ", namedAttr.name) << "{\n"; + + if (isRawValueAttr && canUseUnwrappedRawValue(attr)) { + // If this is a raw value, then we need to wrap it in an Attribute + // instance. + FmtContext fctx; + fctx.withBuilder("odsBuilder"); + + std::string builderTemplate = std::string(attr.getConstBuilderTemplate()); + + // For StringAttr, its constant builder call will wrap the input in + // quotes, which is correct for normal string literals, but incorrect + // here given we use function arguments. So we need to strip the + // wrapping quotes. 
+ if (StringRef(builderTemplate).contains("\"$0\"")) + builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0"); + + std::string value = + std::string(tgfmt(builderTemplate, &fctx, namedAttr.name)); + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n", + builderOpState, namedAttr.name, value); + } else { + body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n", + builderOpState, namedAttr.name); } + if (emitNotNullCheck) + body << " }\n"; } // Create the correct number of regions. @@ -1960,9 +2059,12 @@ body << " unsigned index = 0; (void)index;\n"; for (auto staticValue : llvm::enumerate(values)) { - bool hasPredicate = staticValue.value().hasPredicate(); - bool isOptional = staticValue.value().isOptional(); - if (!hasPredicate && !isOptional) + const NamedTypeConstraint &value = staticValue.value(); + + bool hasPredicate = value.hasPredicate(); + bool isOptional = value.isOptional(); + bool isVariadicOfVariadic = value.isVariadicOfVariadic(); + if (!hasPredicate && !isOptional && !isVariadicOfVariadic) continue; body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n", // Capitalize the first letter to match the function name @@ -1977,14 +2079,21 @@ "<< index << \" requires 0 or 1 element, but found \" << " "valueGroup{0}.size();\n", staticValue.index(), valueKind); + } else if (isVariadicOfVariadic) { + body << formatv( + " if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr(" + "*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n" + " return ::mlir::failure();\n", + value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name, + staticValue.index()); } // Otherwise, if there is no predicate there is nothing left to do. if (!hasPredicate) continue; // Emit a loop to check all the dynamic values in the pack. 
- StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn( - staticValue.value().constraint); + StringRef constraintFn = + staticVerifierEmitter.getTypeConstraintFn(value.constraint); body << " for (::mlir::Value v : valueGroup" << staticValue.index() << ") {\n" << " if (::mlir::failed(" << constraintFn @@ -2257,7 +2366,8 @@ } std::string sizeAttrInit = formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes"); - generateNamedOperandGetters(op, adaptor, sizeAttrInit, + generateNamedOperandGetters(op, adaptor, + /*isAdaptor=*/true, sizeAttrInit, /*rangeType=*/"::mlir::ValueRange", /*rangeBeginCall=*/"odsOperands.begin()", /*rangeSizeCall=*/"odsOperands.size()", diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -497,6 +497,7 @@ /// The set of attributes explicitly used within the format. SmallVector usedAttributes; + llvm::StringSet<> inferredAttributes; }; } // end anonymous namespace @@ -616,10 +617,38 @@ if (parser.parseOperand({0}RawOperands[0])) return ::mlir::failure(); )"; +/// The code snippet used to generate a parser call for a VariadicOfVariadic +/// operand. +/// +/// {0}: The name of the operand. +/// {1}: The name of segment size attribute. +const char *const variadicOfVariadicOperandParserCode = R"( + { + {0}OperandsLoc = parser.getCurrentLocation(); + int32_t curSize = 0; + do { + if (parser.parseOptionalLParen()) + break; + if (parser.parseOperandList({0}Operands) || parser.parseRParen()) + return ::mlir::failure(); + {0}OperandGroupSizes.push_back({0}Operands.size() - curSize); + curSize = {0}Operands.size(); + } while (succeeded(parser.parseOptionalComma())); + } +)"; /// The code snippet used to generate a parser call for a type list. /// /// {0}: The name for the type list. 
+const char *const variadicOfVariadicTypeParserCode = R"( + do { + if (parser.parseOptionalLParen()) + break; + if (parser.parseOptionalRParen() && + (parser.parseTypeList({0}Types) || parser.parseRParen())) + return ::mlir::failure(); + } while (succeeded(parser.parseOptionalComma())); +)"; const char *const variadicTypeParserCode = R"( if (parser.parseTypeList({0}Types)) return ::mlir::failure(); @@ -758,6 +787,9 @@ namespace { /// The type of length for a given parse argument. enum class ArgumentLengthKind { + /// The argument is a variadic of a variadic, and may contain 0->N range + /// elements. + VariadicOfVariadic, /// The argument is variadic, and may contain 0->N elements. Variadic, /// The argument is optional, and may contain 0 or 1 elements. @@ -772,6 +804,8 @@ getArgumentLengthKind(const NamedTypeConstraint *var) { if (var->isOptional()) return ArgumentLengthKind::Optional; + if (var->isVariadicOfVariadic()) + return ArgumentLengthKind::VariadicOfVariadic; if (var->isVariadic()) return ArgumentLengthKind::Variadic; return ArgumentLengthKind::Single; @@ -863,6 +897,10 @@ if (operand->getVar()->isVariableLength()) { body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> " << name << "Operands;\n"; + if (operand->getVar()->isVariadicOfVariadic()) { + body << " llvm::SmallVector " << name + << "OperandGroupSizes;\n"; + } } else { body << " ::mlir::OpAsmParser::OperandType " << name << "RawOperands[1];\n" @@ -924,7 +962,9 @@ } else if (auto *operand = dyn_cast(¶m)) { StringRef name = operand->getVar()->name; ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv("{0}OperandGroups", name); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv("{0}Operands", name); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv("{0}Operand", name); @@ -951,7 
+991,9 @@ } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv("{0}TypeGroups", listName); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv("{0}Types", listName); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv("{0}Type", listName); @@ -972,19 +1014,32 @@ // * Set the location of operand variables. for (Element ¶m : dir->getArguments()) { if (auto *operand = dyn_cast(¶m)) { - body << " " << operand->getVar()->name + auto *var = operand->getVar(); + body << " " << var->name << "OperandsLoc = parser.getCurrentLocation();\n"; - if (operand->getVar()->isOptional()) { + if (var->isOptional()) { body << llvm::formatv( " llvm::Optional<::mlir::OpAsmParser::OperandType> " "{0}Operand;\n", - operand->getVar()->name); + var->name); + } else if (var->isVariadicOfVariadic()) { + body << llvm::formatv(" " + "llvm::SmallVector> " + "{0}OperandGroups;\n", + var->name); } } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Optional) + if (lengthKind == ArgumentLengthKind::Optional) { body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName); + } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { + body << llvm::formatv( + " llvm::SmallVector> " + "{0}TypeGroups;\n", + listName); + } } else if (auto *dir = dyn_cast(¶m)) { Element *input = dir->getOperand(); if (auto *operand = dyn_cast(input)) { @@ -1028,11 +1083,18 @@ var->name); } else if (auto *operand = dyn_cast(¶m)) { const NamedTypeConstraint *var = operand->getVar(); - if (!var->isOptional()) - continue; - body << llvm::formatv(" if ({0}Operand.hasValue())\n" - " {0}Operands.push_back(*{0}Operand);\n", - 
var->name); + if (var->isOptional()) { + body << llvm::formatv(" if ({0}Operand.hasValue())\n" + " {0}Operands.push_back(*{0}Operand);\n", + var->name); + } else if (var->isVariadicOfVariadic()) { + body << llvm::formatv( + " for (const auto &subRange : {0}OperandGroups) {{\n" + " {0}Operands.append(subRange.begin(), subRange.end());\n" + " {0}OperandGroupSizes.push_back(subRange.size());\n" + " }\n", + var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr()); + } } else if (auto *dir = dyn_cast(¶m)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); @@ -1040,6 +1102,11 @@ body << llvm::formatv(" if ({0}Type)\n" " {0}Types.push_back({0}Type);\n", listName); + } else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) { + body << llvm::formatv( + " for (const auto &subRange : {0}TypeGroups)\n" + " {0}Types.append(subRange.begin(), subRange.end());\n", + listName); } } } @@ -1229,7 +1296,11 @@ } else if (auto *operand = dyn_cast(element)) { ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar()); StringRef name = operand->getVar()->name; - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv( + variadicOfVariadicOperandParserCode, name, + operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr()); + else if (lengthKind == ArgumentLengthKind::Variadic) body << llvm::formatv(variadicOperandParserCode, name); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv(optionalOperandParserCode, name); @@ -1281,7 +1352,9 @@ } else if (auto *dir = dyn_cast(element)) { ArgumentLengthKind lengthKind; StringRef listName = getTypeListName(dir->getOperand(), lengthKind); - if (lengthKind == ArgumentLengthKind::Variadic) + if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) + body << llvm::formatv(variadicOfVariadicTypeParserCode, listName); + else if (lengthKind == 
ArgumentLengthKind::Variadic) body << llvm::formatv(variadicTypeParserCode, listName); else if (lengthKind == ArgumentLengthKind::Optional) body << llvm::formatv(optionalTypeParserCode, listName); @@ -1501,19 +1574,29 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, OpMethodBody &body) { - if (!allOperands && - op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { - body << " result.addAttribute(\"operand_segment_sizes\", " - << "parser.getBuilder().getI32VectorAttr({"; - auto interleaveFn = [&](const NamedTypeConstraint &operand) { - // If the operand is variadic emit the parsed size. - if (operand.isVariableLength()) - body << "static_cast(" << operand.name << "Operands.size())"; - else - body << "1"; - }; - llvm::interleaveComma(op.getOperands(), body, interleaveFn); - body << "}));\n"; + if (!allOperands) { + if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { + body << " result.addAttribute(\"operand_segment_sizes\", " + << "parser.getBuilder().getI32VectorAttr({"; + auto interleaveFn = [&](const NamedTypeConstraint &operand) { + // If the operand is variadic emit the parsed size. 
+ if (operand.isVariableLength()) + body << "static_cast(" << operand.name << "Operands.size())"; + else + body << "1"; + }; + llvm::interleaveComma(op.getOperands(), body, interleaveFn); + body << "}));\n"; + } + for (const NamedTypeConstraint &operand : op.getOperands()) { + if (!operand.isVariadicOfVariadic()) + continue; + body << llvm::formatv( + " result.addAttribute(\"{0}\", " + "parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n", + operand.constraint.getVariadicOfVariadicSegmentSizeAttr(), + operand.name); + } } if (!allResultTypes && @@ -1575,6 +1658,10 @@ if (!fmt.allResultTypes && op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) body << "\"result_segment_sizes\", "; + if (!fmt.inferredAttributes.empty()) { + for (const auto &attr : fmt.inferredAttributes) + body << "\"" << attr.getKey() << "\", "; + } llvm::interleaveComma( fmt.usedAttributes, body, [&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; }); @@ -1693,6 +1780,8 @@ return body << "getOperation()->getResultTypes()"; auto *operand = dyn_cast(arg); auto *var = operand ? 
operand->getVar() : cast(arg)->getVar(); + if (var->isVariadicOfVariadic()) + return body << llvm::formatv("{0}().join().getTypes()", var->name); if (var->isVariadic()) return body << var->name << "().getTypes()"; if (var->isOptional()) @@ -1896,7 +1985,12 @@ else body << " p.printAttribute(" << var->name << "Attr());\n"; } else if (auto *operand = dyn_cast(element)) { - if (operand->getVar()->isOptional()) { + if (operand->getVar()->isVariadicOfVariadic()) { + body << " ::llvm::interleaveComma(" << operand->getVar()->name + << "(), p, [&](const auto &operands) { p << \"(\" << operands << " + "\")\"; });\n"; + + } else if (operand->getVar()->isOptional()) { body << " if (::mlir::Value value = " << operand->getVar()->name << "())\n" << " p << value;\n"; @@ -1926,6 +2020,15 @@ } else if (isa(element)) { body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n"; } else if (auto *dir = dyn_cast(element)) { + if (auto *operand = dyn_cast(dir->getOperand())) { + if (operand->getVar()->isVariadicOfVariadic()) { + body << llvm::formatv(" ::llvm::interleaveComma({0}().getTypes(), p, " + "[&](::mlir::TypeRange types) {{ p << \"(\" << " + "types << \")\"; });\n", + operand->getVar()->name); + return; + } + } body << " p << "; genTypeOperandPrinter(dir->getOperand(), body) << ";\n"; } else if (auto *dir = dyn_cast(element)) { @@ -2449,6 +2552,16 @@ while (!iteratorStack.empty()) if (failed(verifyAttributes(loc, iteratorStack))) return ::mlir::failure(); + + // Check for VariadicOfVariadic variables. The segment attribute of those + // variables will be infered. + for (const NamedTypeConstraint *var : seenOperands) { + if (var->constraint.isVariadicOfVariadic()) { + fmt.inferredAttributes.insert( + var->constraint.getVariadicOfVariadicSegmentSizeAttr()); + } + } + return ::mlir::success(); } /// Verify the attribute elements at the back of the given stack of iterators.