diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -296,6 +296,10 @@ Attribute get(StringRef name) const; Attribute get(Identifier name) const; + /// Return the specified named attribute if present, None otherwise. + Optional getNamed(StringRef name) const; + Optional getNamed(Identifier name) const; + /// Support range iteration. using iterator = llvm::ArrayRef::iterator; iterator begin() const; @@ -1508,6 +1512,10 @@ Attribute get(StringRef name) const; Attribute get(Identifier name) const; + /// Return the specified named attribute if present, None otherwise. + Optional getNamed(StringRef name) const; + Optional getNamed(Identifier name) const; + /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. void set(Identifier name, Attribute value); diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -205,6 +205,14 @@ /// 'operands'. void setOperands(ValueRange operands); + /// Replace the operands beginning at 'start' and ending at 'start' + 'length' + /// with the ones provided in 'operands'. 'operands' may be smaller or larger + /// than the range pointed to by 'start'+'length'. + void setOperands(unsigned start, unsigned length, ValueRange operands); + + /// Insert the given operands into the operand list at the given 'index'. + void insertOperands(unsigned index, ValueRange operands); + unsigned getNumOperands() { return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0; } @@ -214,6 +222,15 @@ return getOpOperand(idx).set(value); } + /// Erase the operand at position `idx`. + void eraseOperand(unsigned idx) { eraseOperands(idx); } + + /// Erase the operands starting at position `idx` and ending at position + /// 'idx'+'length'.
+ void eraseOperands(unsigned idx, unsigned length = 1) { + getOperandStorage().eraseOperands(idx, length); + } + // Support operand iteration. using operand_range = OperandRange; using operand_iterator = operand_range::iterator; @@ -221,12 +238,9 @@ operand_iterator operand_begin() { return getOperands().begin(); } operand_iterator operand_end() { return getOperands().end(); } - /// Returns an iterator on the underlying Value's (Value ). + /// Returns an iterator on the underlying Value's. operand_range getOperands() { return operand_range(this); } - /// Erase the operand at position `idx`. - void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); } - MutableArrayRef getOpOperands() { return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands() : MutableArrayRef(); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -369,8 +369,14 @@ /// 'values'. void setOperands(Operation *owner, ValueRange values); - /// Erase an operand held by the storage. - void eraseOperand(unsigned index); + /// Replace the operands beginning at 'start' and ending at 'start' + 'length' + /// with the ones provided in 'operands'. 'operands' may be smaller or larger + /// than the range pointed to by 'start'+'length'. + void setOperands(Operation *owner, unsigned start, unsigned length, + ValueRange operands); + + /// Erase the operands held by the storage within the given range. + void eraseOperands(unsigned start, unsigned length); /// Get the operation operands held by the storage. MutableArrayRef getOperands() { @@ -653,6 +659,62 @@ friend RangeBaseT; }; +//===----------------------------------------------------------------------===// +// MutableOperandRange + +/// This class provides a mutable adaptor for a range of operands. It allows for +/// setting, inserting, and erasing operands from the given range.
+class MutableOperandRange { +public: + /// A pair of a named attribute corresponding to an operand segment attribute, + /// and the index within that attribute. The attribute should correspond to an + /// i32 DenseElementsAttr. + using OperandSegment = std::pair; + + /// Construct a new mutable range from the given operation, operand start index, + /// and range length. `operandSegments` is an optional set of operand segments + /// to be updated when mutating the operand list. + MutableOperandRange(Operation *owner, unsigned start, unsigned length, + ArrayRef operandSegments = llvm::None); + MutableOperandRange(Operation *owner); + + /// Append the given values to the range. + void append(ValueRange values); + + /// Assign this range to the given values. + void assign(ValueRange values); + + /// Assign the range to the given value. + void assign(Value value); + + /// Erase `subLen` operands starting at `subStart`, relative to this range. + void erase(unsigned subStart, unsigned subLen = 1); + + /// Clear this range and erase all of the operands. + void clear(); + + /// Returns the current size of the range. + unsigned size() const { return length; } + + /// Allow implicit conversion to an OperandRange. + operator OperandRange() const; + +private: + /// Update the length of this range to the one provided. + void updateLength(unsigned newLength); + + /// The owning operation of this range. + Operation *owner; + + /// The start index of the operand range within the owner operand list, and + /// the length starting from `start`. + unsigned start, length; + + /// Optional set of operand segments that should be updated when mutating the + /// length of this range.
+ SmallVector, 1> operandSegments; +}; + //===----------------------------------------------------------------------===// // ResultRange diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -164,7 +164,8 @@ other.back = nullptr; nextUse = nullptr; back = nullptr; - insertIntoCurrent(); + if (value) + insertIntoCurrent(); return *this; } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -180,15 +180,26 @@ /// Return the specified attribute if present, null otherwise. Attribute DictionaryAttr::get(StringRef name) const { + Optional attr = getNamed(name); + return attr ? attr->second : nullptr; +} +Attribute DictionaryAttr::get(Identifier name) const { + Optional attr = getNamed(name); + return attr ? attr->second : nullptr; +} + +/// Return the specified named attribute if present, None otherwise. +Optional DictionaryAttr::getNamed(StringRef name) const { ArrayRef values = getValue(); auto it = llvm::lower_bound(values, name, compareNamedAttributeWithName); - return it != values.end() && it->first == name ? it->second : Attribute(); + return it != values.end() && it->first == name ? *it + : Optional(); } -Attribute DictionaryAttr::get(Identifier name) const { +Optional DictionaryAttr::getNamed(Identifier name) const { for (auto elt : getValue()) if (elt.first == name) - return elt.second; - return nullptr; + return elt; + return llvm::None; } DictionaryAttr::iterator DictionaryAttr::begin() const { @@ -1174,6 +1185,14 @@ return attrs ? attrs.get(name) : nullptr; } +/// Return the specified named attribute if present, None otherwise. +Optional NamedAttributeList::getNamed(StringRef name) const { + return attrs ? attrs.getNamed(name) : Optional(); +} +Optional NamedAttributeList::getNamed(Identifier name) const { + return attrs ?
attrs.getNamed(name) : Optional(); +} + /// If the an attribute exists with the specified name, change it to the new /// value. Otherwise, add a new attribute with the specified name/value. void NamedAttributeList::set(Identifier name, Attribute value) { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -243,6 +243,25 @@ assert(operands.empty() && "setting operands without an operand storage"); } +/// Replace the operands beginning at 'start' and ending at 'start' + 'length' +/// with the ones provided in 'operands'. 'operands' may be smaller or larger +/// than the range pointed to by 'start'+'length'. +void Operation::setOperands(unsigned start, unsigned length, + ValueRange operands) { + assert((start + length) <= getNumOperands() && + "invalid operand range specified"); + if (LLVM_LIKELY(hasOperandStorage)) + return getOperandStorage().setOperands(this, start, length, operands); + assert(operands.empty() && "setting operands without an operand storage"); +} + +/// Insert the given operands into the operand list at the given 'index'.
+void Operation::insertOperands(unsigned index, ValueRange operands) { + if (LLVM_LIKELY(hasOperandStorage)) + return setOperands(index, /*length=*/0, operands); + assert(operands.empty() && "inserting operands without an operand storage"); +} + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -13,7 +13,9 @@ #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Block.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" using namespace mlir; //===----------------------------------------------------------------------===// @@ -89,6 +91,55 @@ storageOperands[i].set(values[i]); } +/// Replace the operands beginning at 'start' and ending at 'start' + 'length' +/// with the ones provided in 'operands'. 'operands' may be smaller or larger +/// than the range pointed to by 'start'+'length'. +void detail::OperandStorage::setOperands(Operation *owner, unsigned start, + unsigned length, ValueRange operands) { + // If the new size is the same, we can update inplace. + unsigned newSize = operands.size(); + if (newSize == length) { + MutableArrayRef storageOperands = getOperands(); + for (unsigned i = 0, e = length; i != e; ++i) + storageOperands[start + i].set(operands[i]); + return; + } + // If the new size is less, remove the extra operands and set the rest + // inplace. + if (newSize < length) { + eraseOperands(start + newSize, length - newSize); + setOperands(owner, start, newSize, operands); + return; + } + // Otherwise, the new size is greater so we need to grow the storage. + auto storageOperands = resize(owner, size() + (newSize - length)); + + // Shift operands to the right to make space for the new operands.
+ unsigned rotateSize = storageOperands.size() - (start + length); + auto rbegin = storageOperands.rbegin(); + std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); + + // Update the operands inplace. + for (unsigned i = 0, e = operands.size(); i != e; ++i) + storageOperands[start + i].set(operands[i]); +} + +/// Erase the operands held by the storage within the given range. +void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { + TrailingOperandStorage &storage = getStorage(); + MutableArrayRef operands = storage.getOperands(); + assert((start + length) <= operands.size() && "invalid operand range"); + storage.numOperands -= length; + + // Shift the trailing operands down if the erased range is not at the end. + if (start != storage.numOperands) { + auto indexIt = std::next(operands.begin(), start); + std::rotate(indexIt, std::next(indexIt, length), operands.end()); + } + for (unsigned i = 0; i != length; ++i) + operands[storage.numOperands + i].~OpOperand(); +} + /// Resize the storage to the given size. Returns the array containing the new /// operands. MutableArrayRef detail::OperandStorage::resize(Operation *owner, @@ -149,20 +200,6 @@ return newOperands; } -/// Erase an operand held by the storage. -void detail::OperandStorage::eraseOperand(unsigned index) { - assert(index < size()); - TrailingOperandStorage &storage = getStorage(); - MutableArrayRef operands = storage.getOperands(); - --storage.numOperands; - - // Shift all operands down by 1 if the operand to remove is not at the end.
- auto indexIt = std::next(operands.begin(), index); - if (index != storage.numOperands) - std::rotate(indexIt, std::next(indexIt), operands.end()); - operands[storage.numOperands].~OpOperand(); -} - //===----------------------------------------------------------------------===// // ResultStorage //===----------------------------------------------------------------------===// @@ -235,6 +272,82 @@ return base->getOperandNumber(); } +//===----------------------------------------------------------------------===// +// MutableOperandRange + +/// Construct a new mutable range from the given operation, operand start index, +/// and range length. +MutableOperandRange::MutableOperandRange( + Operation *owner, unsigned start, unsigned length, + ArrayRef operandSegments) + : owner(owner), start(start), length(length), + operandSegments(operandSegments.begin(), operandSegments.end()) { + assert((start + length) <= owner->getNumOperands() && "invalid range"); +} +MutableOperandRange::MutableOperandRange(Operation *owner) + : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} + +/// Append the given values to the range. +void MutableOperandRange::append(ValueRange values) { + if (values.empty()) + return; + owner->insertOperands(start + length, values); + updateLength(length + values.size()); +} + +/// Assign this range to the given values. +void MutableOperandRange::assign(ValueRange values) { + owner->setOperands(start, length, values); + if (length != values.size()) + updateLength(/*newLength=*/values.size()); +} + +/// Assign the range to the given value. +void MutableOperandRange::assign(Value value) { + if (length == 1) { + owner->setOperand(start, value); + } else { + owner->setOperands(start, length, value); + updateLength(/*newLength=*/1); + } +} + +/// Erase `subLen` operands starting at `subStart`, relative to this range.
+void MutableOperandRange::erase(unsigned subStart, unsigned subLen) { + assert((subStart + subLen) <= length && "invalid sub-range"); + if (subLen == 0) // Nothing to erase; also skips segment attribute rewrites. + return; + owner->eraseOperands(start + subStart, subLen); + updateLength(length - subLen); +} + +/// Clear this range and erase all of the operands. +void MutableOperandRange::clear() { + if (length != 0) { + owner->eraseOperands(start, length); + updateLength(/*newLength=*/0); + } +} + +/// Allow implicit conversion to an OperandRange. +MutableOperandRange::operator OperandRange() const { + return owner->getOperands().slice(start, length); +} + +/// Update the length of this range to the one provided. +void MutableOperandRange::updateLength(unsigned newLength) { + length = newLength; + + // Update any of the provided segment attributes. + for (OperandSegment &segment : operandSegments) { + auto attr = segment.second.second.cast(); + SmallVector segments(attr.getValues()); + segments[segment.first] = newLength; + segment.second.second = DenseIntElementsAttr::get(attr.getType(), segments); + owner->setAttr(segment.second.first, segment.second.second); + } +} + //===----------------------------------------------------------------------===// // ResultRange diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -67,6 +67,8 @@ // CHECK: Operation::operand_range getODSOperands(unsigned index); // CHECK: Value a(); // CHECK: Operation::operand_range b(); +// CHECK: ::mlir::MutableOperandRange aMutable(); +// CHECK: ::mlir::MutableOperandRange bMutable(); // CHECK: Operation::result_range getODSResults(unsigned index); // CHECK: Value r(); // CHECK: Region &someRegion(); @@ -119,6 +121,7 @@ // CHECK-LABEL: NS::EOp declarations // CHECK: Value a(); +// CHECK: ::mlir::MutableOperandRange aMutable(); // CHECK: Value b(); // CHECK: static void build(OpBuilder &odsBuilder, OperationState &odsState, /*optional*/Type b,
/*optional*/Value a) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -45,25 +45,23 @@ // {1}: The total number of non-variadic operands/results. // {2}: The total number of variadic operands/results. // {3}: The total number of actual values. -// {4}: The begin iterator of the actual values. -// {5}: "operand" or "result". +// {4}: "operand" or "result". const char *sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; for (unsigned i = 0; i < index; ++i) if (isVariadic[i]) ++prevVariadicCount; - // Calculate how many dynamic values a static variadic {5} corresponds to. - // This assumes all static variadic {5}s have the same dynamic value count. + // Calculate how many dynamic values a static variadic {4} corresponds to. + // This assumes all static variadic {4}s have the same dynamic value count. int variadicSize = ({3} - {1}) / {2}; // `index` passed in as the parameter is the static index which counts each - // {5} (variadic or not) as size 1. So here for each previous static variadic - // {5}, we need to offset by (variadicSize - 1) to get where the dynamic - // value pack for this static {5} starts. - int offset = index + (variadicSize - 1) * prevVariadicCount; + // {4} (variadic or not) as size 1. So here for each previous static variadic + // {4}, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static {4} starts. + int start = index + (variadicSize - 1) * prevVariadicCount; int size = isVariadic[index] ? variadicSize : 1; - - return {{std::next({4}, offset), std::next({4}, offset + size)}; + return {{start, size}; )"; // The logic to calculate the actual value range for a declared operand/result @@ -72,14 +70,23 @@ // (variadic or not). // // {0}: The name of the attribute specifying the segment sizes.
-// {1}: The begin iterator of the actual values. const char *attrSizedSegmentValueRangeCalcCode = R"( auto sizeAttr = getAttrOfType("{0}"); unsigned start = 0; for (unsigned i = 0; i < index; ++i) start += (*(sizeAttr.begin() + i)).getZExtValue(); - unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue(); - return {{std::next({1}, start), std::next({1}, end)}; + unsigned size = (*(sizeAttr.begin() + index)).getZExtValue(); + return {{start, size}; +)"; + +// The logic to build a range of either operand or result values. +// +// {0}: The begin iterator of the actual values. +// {1}: The call to generate the start and length of the value range. +const char *valueRangeReturnCode = R"( + auto valueRange = {1}; + return {{std::next({0}, valueRange.first), + std::next({0}, valueRange.first + valueRange.second)}; )"; static const char *const opCommentHeader = R"( @@ -177,6 +184,9 @@ // Generates getters for named operands. void genNamedOperandGetters(); + // Generates setters for named operands. + void genNamedOperandSetters(); + // Generates getters for named results. void genNamedResultGetters(); @@ -310,6 +320,7 @@ genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); + genNamedOperandSetters(); genNamedResultGetters(); genNamedRegionGetters(); genNamedSuccessorGetters(); @@ -478,6 +489,37 @@ } } +// Generates a method named `methodName` that returns the start index and +// length of the ODS operand or result group at the given static index. +// `valueKind` ("operand" or "result") is used in the generated comments.
+template <typename RangeT> +static void generateValueRangeStartAndEnd( + Class &opClass, StringRef methodName, StringRef valueKind, int numVariadic, + int numNonVariadic, StringRef rangeSizeCall, bool hasAttrSegmentSize, + StringRef segmentSizeAttr, RangeT &&odsValues) { + auto &method = opClass.newMethod("std::pair", methodName, + "unsigned index"); + + if (numVariadic == 0) { + method.body() << " return {index, 1};\n"; + } else if (hasAttrSegmentSize) { + method.body() << formatv(attrSizedSegmentValueRangeCalcCode, + segmentSizeAttr); + } else { + // Because the op can have arbitrarily interleaved variadic and non-variadic + // values, we need to embed a list in the generated method for + // calculation at run-time. + llvm::SmallVector isVariadic; + isVariadic.reserve(llvm::size(odsValues)); + for (auto &it : odsValues) + isVariadic.push_back(it.isVariableLength() ? "true" : "false"); + std::string isVariadicList = llvm::join(isVariadic, ", "); + method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNonVariadic, numVariadic, rangeSizeCall, + valueKind); + } +} + // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`. Uses `rangeType` as the return type of getters that // return a range of operands (individual operands are `Value ` and each @@ -519,32 +561,16 @@ "'SameVariadicOperandSize' traits"); } - // First emit a "sink" getter method upon which we layer all nicer named + // First emit a few "sink" getter methods upon which we layer all nicer named // getter methods. - auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); + generateValueRangeStartAndEnd( + opClass, "getODSOperandIndexAndLength", "operand", numVariadicOperands, + numNormalOperands, rangeSizeCall, attrSizedOperands, + "operand_segment_sizes", const_cast(op).getOperands()); - if (numVariadicOperands == 0) { - // We still need to match the return type, which is a range.
- m.body() << " return {std::next(" << rangeBeginCall - << ", index), std::next(" << rangeBeginCall << ", index + 1)};"; - } else if (attrSizedOperands) { - m.body() << formatv(attrSizedSegmentValueRangeCalcCode, - "operand_segment_sizes", rangeBeginCall); - } else { - // Because the op can have arbitrarily interleaved variadic and non-variadic - // operands, we need to embed a list in the "sink" getter method for - // calculation at run-time. - llvm::SmallVector isVariadic; - isVariadic.reserve(numOperands); - for (int i = 0; i < numOperands; ++i) - isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true" - : "false"); - std::string isVariadicList = llvm::join(isVariadic, ", "); - - m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, - numNormalOperands, numVariadicOperands, rangeSizeCall, - rangeBeginCall, "operand"); - } + auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); + m.body() << formatv(valueRangeReturnCode, rangeBeginCall, + "getODSOperandIndexAndLength(index)"); // Then we emit nicer named getter methods by redirecting to the "sink" getter // method.
@@ -579,6 +605,26 @@ /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); } +void OpEmitter::genNamedOperandSetters() { + auto *attrSizedOperands = op.getTrait("OpTrait::AttrSizedOperandSegments"); + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name.empty()) + continue; + auto &m = opClass.newMethod("::mlir::MutableOperandRange", + (operand.name + "Mutable").str()); + auto &body = m.body(); + body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " return ::mlir::MutableOperandRange(getOperation(), " + "range.first, range.second"; + if (attrSizedOperands) + body << ", ::mlir::MutableOperandRange::OperandSegment(" << i + << "u, *getOperation()->getAttrList().getNamed(" + "\"operand_segment_sizes\"))"; + body << ");\n"; + } +} + void OpEmitter::genNamedResultGetters() { const int numResults = op.getNumResults(); const int numVariadicResults = op.getNumVariableLengthResults(); @@ -607,29 +653,14 @@ "'SameVariadicResultSize' traits"); } + generateValueRangeStartAndEnd( + opClass, "getODSResultIndexAndLength", "result", numVariadicResults, + numNormalResults, "getOperation()->getNumResults()", attrSizedResults, + "result_segment_sizes", op.getResults()); auto &m = opClass.newMethod("Operation::result_range", "getODSResults", "unsigned index"); - - if (numVariadicResults == 0) { - m.body() << " return {std::next(getOperation()->result_begin(), index), " - "std::next(getOperation()->result_begin(), index + 1)};"; - } else if (attrSizedResults) { - m.body() << formatv(attrSizedSegmentValueRangeCalcCode, - "result_segment_sizes", - "getOperation()->result_begin()"); - } else { - llvm::SmallVector isVariadic; - isVariadic.reserve(numResults); - for (int i = 0; i < numResults; ++i) - isVariadic.push_back(op.getResult(i).isVariableLength() ?
"true" - : "false"); - std::string isVariadicList = llvm::join(isVariadic, ", "); - - m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, - numNormalResults, numVariadicResults, - "getOperation()->getNumResults()", - "getOperation()->result_begin()", "result"); - } + m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", + "getODSResultIndexAndLength(index)"); for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -33,7 +33,7 @@ Value operand = useOp->getResult(0); // Create a non-resizable operation with one operand. - Operation *user = createOp(&context, operand, builder.getIntegerType(16)); + Operation *user = createOp(&context, operand); // The same number of operands is okay. user->setOperands(operand); @@ -57,7 +57,7 @@ Value operand = useOp->getResult(0); // Create a resizable operation with one operand. - Operation *user = createOp(&context, operand, builder.getIntegerType(16)); + Operation *user = createOp(&context, operand); // The same number of operands is okay. user->setOperands(operand); @@ -76,4 +76,77 @@ useOp->destroy(); } +TEST(OperandStorageTest, RangeReplace) { + MLIRContext context; + Builder builder(&context); + + Operation *useOp = + createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16)); + Value operand = useOp->getResult(0); + + // Create a resizable operation with one operand. + Operation *user = createOp(&context, operand); + + // Check setting with the same number of operands. + user->setOperands(/*start=*/0, /*length=*/1, operand); + EXPECT_EQ(user->getNumOperands(), 1u); + + // Check setting with more operands.
+ user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand}); + EXPECT_EQ(user->getNumOperands(), 3u); + + // Check setting with fewer operands. + user->setOperands(/*start=*/1, /*length=*/2, {operand}); + EXPECT_EQ(user->getNumOperands(), 2u); + + // Check inserting without replacing operands. + user->setOperands(/*start=*/2, /*length=*/0, {operand}); + EXPECT_EQ(user->getNumOperands(), 3u); + + // Check erasing operands. + user->setOperands(/*start=*/0, /*length=*/3, {}); + EXPECT_EQ(user->getNumOperands(), 0u); + + // Destroy the operations. + user->destroy(); + useOp->destroy(); +} + +TEST(OperandStorageTest, MutableRange) { + MLIRContext context; + Builder builder(&context); + + Operation *useOp = + createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16)); + Value operand = useOp->getResult(0); + + // Create a resizable operation with one operand. + Operation *user = createOp(&context, operand); + + // Check setting with the same number of operands. + MutableOperandRange mutableOperands(user); + mutableOperands.assign(operand); + EXPECT_EQ(mutableOperands.size(), 1u); + EXPECT_EQ(user->getNumOperands(), 1u); + + // Check setting with more operands. + mutableOperands.assign({operand, operand, operand}); + EXPECT_EQ(mutableOperands.size(), 3u); + EXPECT_EQ(user->getNumOperands(), 3u); + + // Check with inserting a new operand. + mutableOperands.append({operand, operand}); + EXPECT_EQ(mutableOperands.size(), 5u); + EXPECT_EQ(user->getNumOperands(), 5u); + + // Check erasing operands. + mutableOperands.clear(); + EXPECT_EQ(mutableOperands.size(), 0u); + EXPECT_EQ(user->getNumOperands(), 0u); + + // Destroy the operations. + user->destroy(); + useOp->destroy(); +} + } // end namespace