diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -13,6 +13,7 @@
 #include "SubElementInterfaces.h"
 
 namespace llvm {
+class BitVector;
 struct fltSemantics;
 } // namespace llvm
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -166,8 +166,8 @@
                                        TypeRange resultTypes);
 
     /// Returns a new function type without the specified arguments and results.
-    FunctionType getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
-                                          ArrayRef<unsigned> resultIndices);
+    FunctionType getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+                                          const llvm::BitVector &resultIndices);
   }];
 }
 
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.h b/mlir/include/mlir/IR/FunctionInterfaces.h
--- a/mlir/include/mlir/IR/FunctionInterfaces.h
+++ b/mlir/include/mlir/IR/FunctionInterfaces.h
@@ -16,6 +16,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/SmallString.h"
 
 namespace mlir {
@@ -82,12 +83,12 @@
                            unsigned originalNumResults, Type newType);
 
 /// Erase the specified arguments and update the function type attribute.
-void eraseFunctionArguments(Operation *op, ArrayRef<unsigned> argIndices,
-                            unsigned originalNumArgs, Type newType);
+void eraseFunctionArguments(Operation *op, const llvm::BitVector &argIndices,
+                            Type newType);
 
 /// Erase the specified results and update the function type attribute.
-void eraseFunctionResults(Operation *op, ArrayRef<unsigned> resultIndices,
-                          unsigned originalNumResults, Type newType);
+void eraseFunctionResults(Operation *op, const llvm::BitVector &resultIndices,
+                          Type newType);
 
 /// Set a FunctionOpInterface operation's type signature.
 void setFunctionType(Operation *op, Type newType);
@@ -100,7 +101,7 @@
 
 /// Filters out any elements referenced by `indices`. If any types are removed,
 /// `storage` is used to hold the new type list. Returns the new type list.
-TypeRange filterTypesOut(TypeRange types, ArrayRef<unsigned> indices,
+TypeRange filterTypesOut(TypeRange types, const llvm::BitVector &indices,
                          SmallVectorImpl<Type> &storage);
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/FunctionInterfaces.td b/mlir/include/mlir/IR/FunctionInterfaces.td
--- a/mlir/include/mlir/IR/FunctionInterfaces.td
+++ b/mlir/include/mlir/IR/FunctionInterfaces.td
@@ -280,27 +280,33 @@
     }
 
     /// Erase a single argument at `argIndex`.
-    void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+    void eraseArgument(unsigned argIndex) {
+      llvm::BitVector argsToErase($_op.getNumArguments());
+      argsToErase.set(argIndex);
+      eraseArguments(argsToErase);
+    }
 
     /// Erases the arguments listed in `argIndices`.
-    /// `argIndices` is allowed to have duplicates and can be in any order.
-    void eraseArguments(ArrayRef<unsigned> argIndices) {
-      unsigned originalNumArgs = $_op.getNumArguments();
-      Type newType = $_op.getTypeWithoutArgsAndResults(argIndices, {});
-      function_interface_impl::eraseFunctionArguments(this->getOperation(), argIndices,
-                                                      originalNumArgs, newType);
+    /// `argIndices` must hold one bit per argument; set bits are erased.
+    void eraseArguments(const llvm::BitVector &argIndices) {
+      Type newType = $_op.getTypeWithoutArgs(argIndices);
+      function_interface_impl::eraseFunctionArguments(
+        this->getOperation(), argIndices, newType);
     }
 
     /// Erase a single result at `resultIndex`.
-    void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+    void eraseResult(unsigned resultIndex) {
+      llvm::BitVector resultsToErase($_op.getNumResults());
+      resultsToErase.set(resultIndex);
+      eraseResults(resultsToErase);
+    }
 
     /// Erases the results listed in `resultIndices`.
-    /// `resultIndices` is allowed to have duplicates and can be in any order.
-    void eraseResults(ArrayRef<unsigned> resultIndices) {
-      unsigned originalNumResults = $_op.getNumResults();
-      Type newType = $_op.getTypeWithoutArgsAndResults({}, resultIndices);
+    /// `resultIndices` must hold one bit per result; set bits are erased.
+    void eraseResults(const llvm::BitVector &resultIndices) {
+      Type newType = $_op.getTypeWithoutResults(resultIndices);
       function_interface_impl::eraseFunctionResults(
-        this->getOperation(), resultIndices, originalNumResults, newType);
+        this->getOperation(), resultIndices, newType);
     }
 
     /// Return the type of this function with the specified arguments and
@@ -320,10 +326,9 @@
 
     /// Return the type of this function without the specified arguments and
     /// results. This is used to update the function's signature in the
-    /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
-    /// allowed to have duplicates and can be in any order.
+    /// `eraseArguments` and `eraseResults` methods.
     Type getTypeWithoutArgsAndResults(
-      ArrayRef<unsigned> argIndices, ArrayRef<unsigned> resultIndices) {
+      const llvm::BitVector &argIndices, const llvm::BitVector &resultIndices) {
       SmallVector<Type> argStorage, resultStorage;
       TypeRange newArgTypes = function_interface_impl::filterTypesOut(
           $_op.getArgumentTypes(), argIndices, argStorage);
@@ -331,6 +336,24 @@
           $_op.getResultTypes(), resultIndices, resultStorage);
       return $_op.cloneTypeWith(newArgTypes, newResultTypes);
     }
+
+    /// Return the type of this function without the specified arguments.
+    /// `argIndices` must hold one bit per argument; set bits are dropped.
+    Type getTypeWithoutArgs(const llvm::BitVector &argIndices) {
+      SmallVector<Type> argStorage;
+      TypeRange newArgTypes = function_interface_impl::filterTypesOut(
+          $_op.getArgumentTypes(), argIndices, argStorage);
+      return $_op.cloneTypeWith(newArgTypes, $_op.getResultTypes());
+    }
+
+    /// Return the type of this function without the specified results.
+    /// `resultIndices` must hold one bit per result; set bits are dropped.
+    Type getTypeWithoutResults(const llvm::BitVector &resultIndices) {
+      SmallVector<Type> resultStorage;
+      TypeRange newResultTypes = function_interface_impl::filterTypesOut(
+          $_op.getResultTypes(), resultIndices, resultStorage);
+      return $_op.cloneTypeWith($_op.getArgumentTypes(), newResultTypes);
+    }
 
     //===------------------------------------------------------------------===//
     // Argument Attributes
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -24,10 +24,10 @@
 
   // Collect information about the results will become appended arguments.
   SmallVector<Type, 6> erasedResultTypes;
-  SmallVector<unsigned, 6> erasedResultIndices;
+  llvm::BitVector erasedResultIndices(functionType.getNumResults());
   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
     if (resultType.value().isa<BaseMemRefType>()) {
-      erasedResultIndices.push_back(resultType.index());
+      erasedResultIndices.set(resultType.index());
       erasedResultTypes.push_back(resultType.value());
     }
   }
@@ -40,9 +40,13 @@
   func.setType(newFunctionType);
 
   // Transfer the result attributes to arg attributes.
-  for (int i = 0, e = erasedResultTypes.size(); i < e; i++)
+  // The set bits are visited in ascending order, which matches the order in
+  // which `erasedResultTypes` was populated above.
+  auto erasedIndicesIt = erasedResultIndices.set_bits_begin();
+  for (int i = 0, e = erasedResultTypes.size(); i < e; ++i, ++erasedIndicesIt) {
     func.setArgAttrs(functionType.getNumInputs() + i,
-                     func.getResultAttrs(erasedResultIndices[i]));
+                     func.getResultAttrs(*erasedIndicesIt));
+  }
 
   // Erase the results.
   func.eraseResults(erasedResultIndices);
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -172,8 +172,8 @@
 
 /// Returns a new function type without the specified arguments and results.
 FunctionType
-FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
-                                       ArrayRef<unsigned> resultIndices) {
+FunctionType::getWithoutArgsAndResults(const llvm::BitVector &argIndices,
+                                       const llvm::BitVector &resultIndices) {
   SmallVector<Type> argStorage, resultStorage;
   TypeRange newArgTypes = function_interface_impl::filterTypesOut(
       getInputs(), argIndices, argStorage);
diff --git a/mlir/lib/IR/FunctionInterfaces.cpp b/mlir/lib/IR/FunctionInterfaces.cpp
--- a/mlir/lib/IR/FunctionInterfaces.cpp
+++ b/mlir/lib/IR/FunctionInterfaces.cpp
@@ -7,26 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/FunctionInterfaces.h"
-#include "mlir/Support/LLVM.h"
-#include "llvm/ADT/BitVector.h"
 
 using namespace mlir;
 
-/// Helper to call a callback once on each index in the range
-/// [0, `totalIndices`), *except* for the indices given in `indices`.
-/// `indices` is allowed to have duplicates and can be in any order.
-inline static void iterateIndicesExcept(unsigned totalIndices,
-                                        ArrayRef<unsigned> indices,
-                                        function_ref<void(unsigned)> callback) {
-  llvm::BitVector skipIndices(totalIndices);
-  for (unsigned i : indices)
-    skipIndices.set(i);
-
-  for (unsigned i = 0; i < totalIndices; ++i)
-    if (!skipIndices.test(i))
-      callback(i);
-}
-
 //===----------------------------------------------------------------------===//
 // Tablegen Interface Definitions
 //===----------------------------------------------------------------------===//
@@ -217,8 +200,7 @@
 }
 
 void mlir::function_interface_impl::eraseFunctionArguments(
-    Operation *op, ArrayRef<unsigned> argIndices, unsigned originalNumArgs,
-    Type newType) {
+    Operation *op, const llvm::BitVector &argIndices, Type newType) {
   // There are 3 things that need to be updated:
   // - Function type.
   // - Arg attrs.
@@ -229,9 +211,9 @@
   if (auto argAttrs = op->getAttrOfType<ArrayAttr>(getArgDictAttrName())) {
     SmallVector<DictionaryAttr, 4> newArgAttrs;
     newArgAttrs.reserve(argAttrs.size());
-    iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
-      newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
-    });
+    for (unsigned i = 0, e = argIndices.size(); i < e; ++i)
+      if (!argIndices[i])
+        newArgAttrs.emplace_back(argAttrs[i].cast<DictionaryAttr>());
     setAllArgAttrDicts(op, newArgAttrs);
   }
 
@@ -241,8 +223,7 @@
 }
 
 void mlir::function_interface_impl::eraseFunctionResults(
-    Operation *op, ArrayRef<unsigned> resultIndices,
-    unsigned originalNumResults, Type newType) {
+    Operation *op, const llvm::BitVector &resultIndices, Type newType) {
   // There are 2 things that need to be updated:
   // - Function type.
   // - Result attrs.
@@ -251,9 +232,9 @@
   if (auto resAttrs = op->getAttrOfType<ArrayAttr>(getResultDictAttrName())) {
     SmallVector<DictionaryAttr, 4> newResultAttrs;
     newResultAttrs.reserve(resAttrs.size());
-    iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
-      newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
-    });
+    for (unsigned i = 0, e = resultIndices.size(); i < e; ++i)
+      if (!resultIndices[i])
+        newResultAttrs.emplace_back(resAttrs[i].cast<DictionaryAttr>());
     setAllResultAttrDicts(op, newResultAttrs);
   }
 
@@ -282,12 +263,15 @@
 
 TypeRange
 mlir::function_interface_impl::filterTypesOut(TypeRange types,
-                                              ArrayRef<unsigned> indices,
+                                              const llvm::BitVector &indices,
                                               SmallVectorImpl<Type> &storage) {
-  if (indices.empty())
+  if (indices.none())
     return types;
-  iterateIndicesExcept(types.size(), indices,
-                       [&](unsigned i) { storage.emplace_back(types[i]); });
+
+  // `indices` holds one bit per entry of `types`; keep the unset entries.
+  for (unsigned i = 0, e = types.size(); i < e; ++i)
+    if (!indices[i])
+      storage.emplace_back(types[i]);
   return storage;
 }
 
diff --git a/mlir/test/lib/IR/TestFunc.cpp b/mlir/test/lib/IR/TestFunc.cpp
--- a/mlir/test/lib/IR/TestFunc.cpp
+++ b/mlir/test/lib/IR/TestFunc.cpp
@@ -87,18 +87,10 @@
     auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
-      SmallVector<unsigned, 4> indicesToErase;
-      for (auto argIndex : llvm::seq<int>(0, func.getNumArguments())) {
-        if (func.getArgAttr(argIndex, "test.erase_this_arg")) {
-          // Push back twice to test that duplicate arg indices are handled
-          // correctly.
-          indicesToErase.push_back(argIndex);
-          indicesToErase.push_back(argIndex);
-        }
-      }
-      // Reverse the order to test that unsorted index lists are handled
-      // correctly.
-      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      llvm::BitVector indicesToErase(func.getNumArguments());
+      for (auto argIndex : llvm::seq<int>(0, func.getNumArguments()))
+        if (func.getArgAttr(argIndex, "test.erase_this_arg"))
+          indicesToErase.set(argIndex);
       func.eraseArguments(indicesToErase);
     }
   }
@@ -115,18 +107,10 @@
     auto module = getOperation();
 
     for (FuncOp func : module.getOps<FuncOp>()) {
-      SmallVector<unsigned, 4> indicesToErase;
-      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults())) {
-        if (func.getResultAttr(resultIndex, "test.erase_this_result")) {
-          // Push back twice to test that duplicate indices are handled
-          // correctly.
-          indicesToErase.push_back(resultIndex);
-          indicesToErase.push_back(resultIndex);
-        }
-      }
-      // Reverse the order to test that unsorted index lists are handled
-      // correctly.
-      std::reverse(indicesToErase.begin(), indicesToErase.end());
+      llvm::BitVector indicesToErase(func.getNumResults());
+      for (auto resultIndex : llvm::seq<int>(0, func.getNumResults()))
+        if (func.getResultAttr(resultIndex, "test.erase_this_result"))
+          indicesToErase.set(resultIndex);
       func.eraseResults(indicesToErase);
     }
   }