diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -205,6 +205,11 @@ /// 'operands'. void setOperands(ValueRange operands); + /// Replace the operands in the half-open range ['start', 'start' + 'length') + /// with the ones provided in 'operands'. 'operands' may be smaller or larger + /// than that range, in which case the operand list shrinks or grows. + void setOperands(unsigned start, unsigned length, ValueRange operands); + unsigned getNumOperands() { return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0; } @@ -214,6 +219,15 @@ return getOpOperand(idx).set(value); } + /// Erase the operand at position `idx`. + void eraseOperand(unsigned idx) { eraseOperands(idx); } + + /// Erase the operands in the half-open range ['idx', 'idx' + 'length'). + /// Accesses the operand storage without checking `hasOperandStorage`. + void eraseOperands(unsigned idx, unsigned length = 1) { + getOperandStorage().eraseOperands(idx, length); + } + // Support operand iteration. using operand_range = OperandRange; using operand_iterator = operand_range::iterator; @@ -221,12 +235,9 @@ operand_iterator operand_begin() { return getOperands().begin(); } operand_iterator operand_end() { return getOperands().end(); } - /// Returns an iterator on the underlying Value's (Value ). + /// Returns a range over the operand values. operand_range getOperands() { return operand_range(this); } - /// Erase the operand at position `idx`. - void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); } - MutableArrayRef getOpOperands() { return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands() : MutableArrayRef(); diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -369,8 +369,14 @@ /// 'values'.
void setOperands(Operation *owner, ValueRange values); - /// Erase an operand held by the storage. - void eraseOperand(unsigned index); + /// Replace the operands in the half-open range ['start', 'start' + 'length') + /// with the ones provided in 'operands'. 'operands' may be smaller or larger + /// than that range, in which case the storage shrinks or grows. + void setOperands(Operation *owner, unsigned start, unsigned length, + ValueRange operands); + + /// Erase the operands in the half-open range ['start', 'start' + 'length'). + void eraseOperands(unsigned start, unsigned length); /// Get the operation operands held by the storage. MutableArrayRef getOperands() { diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -164,7 +164,8 @@ other.back = nullptr; nextUse = nullptr; back = nullptr; - insertIntoCurrent(); + if (value) + insertIntoCurrent(); return *this; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -243,6 +243,18 @@ assert(operands.empty() && "setting operands without an operand storage"); } +/// Replace the operands in the half-open range ['start', 'start' + 'length') +/// with the ones provided in 'operands'. 'operands' may be smaller or larger +/// than that range, in which case the operand list shrinks or grows.
+void Operation::setOperands(unsigned start, unsigned length, + ValueRange operands) { + assert((start + length) <= getNumOperands() && + "invalid operand range specified"); + if (LLVM_LIKELY(hasOperandStorage)) + return getOperandStorage().setOperands(this, start, length, operands); + assert(operands.empty() && "setting operands without an operand storage"); +} + //===----------------------------------------------------------------------===// // Diagnostics //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -89,6 +89,55 @@ storageOperands[i].set(values[i]); } +/// Replace the operands in the half-open range ['start', 'start' + 'length') +/// with 'operands'; the storage is shrunk or grown when the sizes differ. +/// NOTE(review): presumably 'operands' must not alias this storage - confirm. +void detail::OperandStorage::setOperands(Operation *owner, unsigned start, + unsigned length, ValueRange operands) { + // If the new size is the same, we can update inplace. + unsigned newSize = operands.size(); + if (newSize == length) { + MutableArrayRef storageOperands = getOperands(); + for (unsigned i = 0, e = length; i != e; ++i) + storageOperands[start + i].set(operands[i]); + return; + } + // If the new size is smaller, erase the surplus operands at the end of the + // range and set the rest inplace. + if (newSize < length) { + eraseOperands(start + operands.size(), length - newSize); + setOperands(owner, start, newSize, operands); + return; + } + // Otherwise, the new size is greater so we need to grow the storage. + auto storageOperands = resize(owner, size() + (newSize - length)); + + // Shift the trailing operands right (reverse rotate) to make space.
+ unsigned rotateSize = storageOperands.size() - (start + length); + auto rbegin = storageOperands.rbegin(); + std::rotate(rbegin, std::next(rbegin, newSize - length), rbegin + rotateSize); + + // Update the operands inplace. + for (unsigned i = 0, e = operands.size(); i != e; ++i) + storageOperands[start + i].set(operands[i]); +} + +/// Erase the operands in the half-open range ['start', 'start' + 'length'). +void detail::OperandStorage::eraseOperands(unsigned start, unsigned length) { + TrailingOperandStorage &storage = getStorage(); + MutableArrayRef operands = storage.getOperands(); + assert((start + length) <= operands.size()); + storage.numOperands -= length; + + // Move the erased operands to the end of the storage, then destroy them. + if (start != storage.numOperands) { + auto indexIt = std::next(operands.begin(), start); + std::rotate(indexIt, std::next(indexIt, length), operands.end()); + } + for (unsigned i = 0; i != length; ++i) + operands[storage.numOperands + i].~OpOperand(); +} + /// Resize the storage to the given size. Returns the array containing the new /// operands. MutableArrayRef detail::OperandStorage::resize(Operation *owner, @@ -149,20 +198,6 @@ return newOperands; } -/// Erase an operand held by the storage. -void detail::OperandStorage::eraseOperand(unsigned index) { - assert(index < size()); - TrailingOperandStorage &storage = getStorage(); - MutableArrayRef operands = storage.getOperands(); - --storage.numOperands; - - // Shift all operands down by 1 if the operand to remove is not at the end.
- auto indexIt = std::next(operands.begin(), index); - if (index != storage.numOperands) - std::rotate(indexIt, std::next(indexIt), operands.end()); - operands[storage.numOperands].~OpOperand(); -} - //===----------------------------------------------------------------------===// // ResultStorage //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -67,6 +67,8 @@ // CHECK: Operation::operand_range getODSOperands(unsigned index); // CHECK: Value a(); // CHECK: Operation::operand_range b(); +// CHECK: void a(Value value); +// CHECK: void b(ValueRange values); // CHECK: Operation::result_range getODSResults(unsigned index); // CHECK: Value r(); // CHECK: Region &someRegion(); @@ -119,6 +121,7 @@ // CHECK-LABEL: NS::EOp declarations // CHECK: Value a(); +// CHECK: void a(Value value); // CHECK: Value b(); // CHECK: static void build(Builder *odsBuilder, OperationState &odsState, /*optional*/Type b, /*optional*/Value a) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -45,25 +45,23 @@ // {1}: The total number of non-variadic operands/results. // {2}: The total number of variadic operands/results. // {3}: The total number of actual values. -// {4}: The begin iterator of the actual values. -// {5}: "operand" or "result". +// {4}: "operand" or "result" (only used in the generated comments). const char *sameVariadicSizeValueRangeCalcCode = R"( bool isVariadic[] = {{{0}}; int prevVariadicCount = 0; for (unsigned i = 0; i < index; ++i) if (isVariadic[i]) ++prevVariadicCount; - // Calculate how many dynamic values a static variadic {5} corresponds to. - // This assumes all static variadic {5}s have the same dynamic value count.
+ // Calculate how many dynamic values a static variadic {4} corresponds to. + // This assumes all static variadic {4}s have the same dynamic value count. int variadicSize = ({3} - {1}) / {2}; // `index` passed in as the parameter is the static index which counts each - // {5} (variadic or not) as size 1. So here for each previous static variadic - // {5}, we need to offset by (variadicSize - 1) to get where the dynamic - // value pack for this static {5} starts. - int offset = index + (variadicSize - 1) * prevVariadicCount; + // {4} (variadic or not) as size 1. So here for each previous static variadic + // {4}, we need to offset by (variadicSize - 1) to get where the dynamic + // value pack for this static {4} starts. + int start = index + (variadicSize - 1) * prevVariadicCount; int size = isVariadic[index] ? variadicSize : 1; - - return {{std::next({4}, offset), std::next({4}, offset + size)}; + return {{start, size}; )"; // The logic to calculate the actual value range for a declared operand/result @@ -72,14 +70,23 @@ // (variadic or not). // // {0}: The name of the attribute specifying the segment sizes. -// {1}: The begin iterator of the actual values. const char *attrSizedSegmentValueRangeCalcCode = R"( auto sizeAttr = getAttrOfType("{0}"); unsigned start = 0; for (unsigned i = 0; i < index; ++i) start += (*(sizeAttr.begin() + i)).getZExtValue(); - unsigned end = start + (*(sizeAttr.begin() + index)).getZExtValue(); - return {{std::next({1}, start), std::next({1}, end)}; + unsigned size = (*(sizeAttr.begin() + index)).getZExtValue(); + return {{start, size}; +)"; + +// The logic to build a range of either operand or result values. +// +// {0}: The begin iterator of the actual values. +// {1}: A call expression yielding the (start, length) pair of the range.
+const char *valueRangeReturnCode = R"( + auto valueRange = {1}; + return {{std::next({0}, valueRange.first), + std::next({0}, valueRange.first + valueRange.second)}; )"; static const char *const opCommentHeader = R"( @@ -177,6 +184,9 @@ // Generates getters for named operands. void genNamedOperandGetters(); + // Generates setters for named operands. + void genNamedOperandSetters(); + // Generates getters for named results. void genNamedResultGetters(); @@ -310,6 +320,7 @@ genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); + genNamedOperandSetters(); genNamedResultGetters(); genNamedRegionGetters(); genNamedSuccessorGetters(); @@ -478,6 +489,40 @@ } } +// Generates a method named `methodName` returning the start index and length +// of an operand or result range; `valueKind` is "operand" or "result". +template +static void +generateValueRangeStartAndEnd(Class &opClass, StringRef methodName, + int numVariadic, int numNonVariadic, + StringRef rangeSizeCall, bool hasAttrSegmentSize, + StringRef segmentSizeAttr, RangeT &&odsValues, + StringRef valueKind = "operand") { + auto &method = opClass.newMethod("std::pair", methodName, + "unsigned index"); + + if (numVariadic == 0) { + // Without variadic values, each ODS value maps to exactly one actual value. + method.body() << " return {index, 1};\n"; + } else if (hasAttrSegmentSize) { + // The segment sizes are read at run-time from the `segmentSizeAttr` attr. + method.body() << formatv(attrSizedSegmentValueRangeCalcCode, + segmentSizeAttr); + } else { + // Because the op can have arbitrarily interleaved variadic and non-variadic + // values, we need to embed a list in the generated method for calculation + // at run-time. `valueKind` only affects the generated comments. + llvm::SmallVector isVariadic; + isVariadic.reserve(llvm::size(odsValues)); + for (auto &it : odsValues) + isVariadic.push_back(it.isVariableLength() ? "true" : "false"); + std::string isVariadicList = llvm::join(isVariadic, ", "); + method.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, + numNonVariadic, numVariadic, rangeSizeCall, + valueKind); + } +} + // Generates the named operand getter methods for the given Operator `op` and // puts them in `opClass`.
Uses `rangeType` as the return type of getters that // return a range of operands (individual operands are `Value ` and each @@ -519,32 +564,16 @@ "'SameVariadicOperandSize' traits"); } - // First emit a "sink" getter method upon which we layer all nicer named + // First emit a few "sink" getter methods upon which we layer all nicer named // getter methods. - auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); - - if (numVariadicOperands == 0) { - // We still need to match the return type, which is a range. - m.body() << " return {std::next(" << rangeBeginCall - << ", index), std::next(" << rangeBeginCall << ", index + 1)};"; - } else if (attrSizedOperands) { - m.body() << formatv(attrSizedSegmentValueRangeCalcCode, - "operand_segment_sizes", rangeBeginCall); - } else { - // Because the op can have arbitrarily interleaved variadic and non-variadic - // operands, we need to embed a list in the "sink" getter method for - // calculation at run-time. - llvm::SmallVector isVariadic; - isVariadic.reserve(numOperands); - for (int i = 0; i < numOperands; ++i) - isVariadic.push_back(op.getOperand(i).isVariableLength() ? "true" - : "false"); - std::string isVariadicList = llvm::join(isVariadic, ", "); + generateValueRangeStartAndEnd( + opClass, "getODSOperandIndexAndLength", numVariadicOperands, + numNormalOperands, rangeSizeCall, attrSizedOperands, + "operand_segment_sizes", const_cast(op).getOperands()); - m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, - numNormalOperands, numVariadicOperands, rangeSizeCall, - rangeBeginCall, "operand"); - } + auto &m = opClass.newMethod(rangeType, "getODSOperands", "unsigned index"); + m.body() << formatv(valueRangeReturnCode, rangeBeginCall, + "getODSOperandIndexAndLength(index)"); // Then we emit nicer named getter methods by redirecting to the "sink" getter // method.
@@ -579,6 +608,50 @@ /*getOperandCallPattern=*/"getOperation()->getOperand({0})"); } +void OpEmitter::genNamedOperandSetters() { + const auto *attrSizedOperands = + op.getTrait("OpTrait::AttrSizedOperandSegments"); + // Code block for updating the segment size attribute if necessary. + const char *updateAttrSegmentsCode = R"( + // If the sizes are the same, there is nothing to update. + if (values.size() == range.second) + return; + auto attr = getAttrOfType("operand_segment_sizes"); + SmallVector segments(attr.getValues()); + segments[{0}] = values.size(); + setAttr("operand_segment_sizes", + DenseIntElementsAttr::get(attr.getType(), segments)); + )"; + + for (int i = 0, e = op.getNumOperands(); i != e; ++i) { + const auto &operand = op.getOperand(i); + if (operand.name.empty()) + continue; + // If the operand is statically sized, set it directly. + if (!operand.isVariableLength()) { + auto &m = opClass.newMethod("void", operand.name, "Value value"); + m.body() << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " this->getOperation()->setOperand(range.first, value);\n"; + continue; + } + + // Otherwise, replace the current range with the new values and update the + // segment size attribute if the op has one. + OpMethod *m = nullptr; + if (operand.isOptional()) { + m = &opClass.newMethod("void", operand.name, "Value value"); + m->body() << " ::llvm::ArrayRef values(&value, value ?
1 : 0);\n"; + } else if (operand.isVariadic()) { + m = &opClass.newMethod("void", operand.name, "ValueRange values"); + } + m->body() << " auto range = getODSOperandIndexAndLength(" << i << ");\n" + << " this->getOperation()->setOperands(range.first, " + "range.second, values);\n"; + if (attrSizedOperands) + m->body() << llvm::formatv(updateAttrSegmentsCode, i); + } +} + void OpEmitter::genNamedResultGetters() { const int numResults = op.getNumResults(); const int numVariadicResults = op.getNumVariableLengthResults(); @@ -607,29 +680,14 @@ "'SameVariadicResultSize' traits"); } + generateValueRangeStartAndEnd( + opClass, "getODSResultIndexAndLength", numVariadicResults, + numNormalResults, "getOperation()->getNumResults()", attrSizedResults, + "result_segment_sizes", op.getResults(), "result"); auto &m = opClass.newMethod("Operation::result_range", "getODSResults", "unsigned index"); - - if (numVariadicResults == 0) { - m.body() << " return {std::next(getOperation()->result_begin(), index), " - "std::next(getOperation()->result_begin(), index + 1)};"; - } else if (attrSizedResults) { - m.body() << formatv(attrSizedSegmentValueRangeCalcCode, - "result_segment_sizes", - "getOperation()->result_begin()"); - } else { - llvm::SmallVector isVariadic; - isVariadic.reserve(numResults); - for (int i = 0; i < numResults; ++i) - isVariadic.push_back(op.getResult(i).isVariableLength() ?
"true" - : "false"); - std::string isVariadicList = llvm::join(isVariadic, ", "); - - m.body() << formatv(sameVariadicSizeValueRangeCalcCode, isVariadicList, - numNormalResults, numVariadicResults, - "getOperation()->getNumResults()", - "getOperation()->result_begin()", "result"); - } + m.body() << formatv(valueRangeReturnCode, "getOperation()->result_begin()", + "getODSResultIndexAndLength(index)"); for (int i = 0; i != numResults; ++i) { const auto &result = op.getResult(i); diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -33,7 +33,7 @@ Value operand = useOp->getResult(0); // Create a non-resizable operation with one operand. - Operation *user = createOp(&context, operand, builder.getIntegerType(16)); + Operation *user = createOp(&context, operand); // The same number of operands is okay. user->setOperands(operand); @@ -57,7 +57,7 @@ Value operand = useOp->getResult(0); // Create a resizable operation with one operand. - Operation *user = createOp(&context, operand, builder.getIntegerType(16)); + Operation *user = createOp(&context, operand); // The same number of operands is okay. user->setOperands(operand); @@ -76,4 +76,40 @@ useOp->destroy(); } +TEST(OperandStorageTest, RangeReplace) { + MLIRContext context; + Builder builder(&context); + + Operation *useOp = + createOp(&context, /*operands=*/llvm::None, builder.getIntegerType(16)); + Value operand = useOp->getResult(0); + + // Create a resizable operation with one operand. + Operation *user = createOp(&context, operand); + + // Check setting with the same number of operands. + user->setOperands(/*start=*/0, /*length=*/1, operand); + EXPECT_EQ(user->getNumOperands(), 1u); + + // Check setting with more operands.
+ user->setOperands(/*start=*/0, /*length=*/1, {operand, operand, operand}); + EXPECT_EQ(user->getNumOperands(), 3u); + + // Check setting with fewer operands. + user->setOperands(/*start=*/1, /*length=*/2, {operand}); + EXPECT_EQ(user->getNumOperands(), 2u); + + // Check inserting into a zero-length range without replacing operands. + user->setOperands(/*start=*/2, /*length=*/0, {operand}); + EXPECT_EQ(user->getNumOperands(), 3u); + + // Check erasing all operands by replacing them with an empty range. + user->setOperands(/*start=*/0, /*length=*/3, {}); + EXPECT_EQ(user->getNumOperands(), 0u); + + // Destroy the operations. + user->destroy(); + useOp->destroy(); +} + } // end namespace