diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -255,11 +255,17 @@
 -   they can have argument and result attributes that are stored in dictionary
     attributes on the operation itself.
 
-This trait does *NOT* provide type support for the functions, meaning that
-concrete Ops must handle the type of the declared or defined function.
-`getTypeAttrName()` is a convenience function that returns the name of the
-attribute that can be used to store the function type, but the trait makes no
-assumption based on it.
+This trait provides limited type support for the declared or defined functions.
+The convenience function `getTypeAttrName()` returns the name of an attribute
+that can be used to store the function type. In addition, this trait provides
+`getType` and `setType` helpers to read and write a `FunctionType` stored in
+the attribute named by `getTypeAttrName()`.
+
+In general, this trait assumes concrete ops use `FunctionType` under the hood.
+If this is not the case, in order to use the function type support, concrete ops
+must define the following methods, using the same name, to hide the ones defined
+for `FunctionType`: `addBodyBlock`, `getType`, `getTypeWithoutArgsAndResults`
+and `setType`.
 
 ### HasParent
 
diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -16,6 +16,10 @@
 #include "mlir/IR/BlockSupport.h"
 #include "mlir/IR/Visitors.h"
 
+namespace llvm {
+class BitVector;
+} // end namespace llvm
+
 namespace mlir {
 class TypeRange;
 template class ValueTypeRange;
@@ -98,6 +102,13 @@
 
   /// Erase the argument at 'index' and remove it from the argument list.
   void eraseArgument(unsigned index);
+  /// Erases the arguments listed in `argIndices` and removes them from the
+  /// argument list.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef argIndices);
+  /// Erases the arguments whose bit is set in `eraseIndices` (indexed by
+  /// argument position) and removes them from the argument list.
+  void eraseArguments(llvm::BitVector eraseIndices);
 
   unsigned getNumArguments() { return arguments.size(); }
   BlockArgument getArgument(unsigned i) { return arguments[i]; }
diff --git a/mlir/include/mlir/IR/Function.h b/mlir/include/mlir/IR/Function.h
--- a/mlir/include/mlir/IR/Function.h
+++ b/mlir/include/mlir/IR/Function.h
@@ -59,18 +59,6 @@
   void print(OpAsmPrinter &p);
   LogicalResult verify();
 
-  /// Erase a single argument at `argIndex`.
-  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
-  /// Erases the arguments listed in `argIndices`.
-  /// `argIndices` is allowed to have duplicates and can be in any order.
-  void eraseArguments(ArrayRef argIndices);
-
-  /// Erase a single result at `resultIndex`.
-  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
-  /// Erases the results listed in `resultIndices`.
-  /// `resultIndices` is allowed to have duplicates and can be in any order.
-  void eraseResults(ArrayRef resultIndices);
-
   /// Create a deep copy of this function and all of its blocks, remapping
   /// any operands that use values outside of the function using the map that is
   /// provided (leaving them alone if no entry is present). If the mapper
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -71,6 +71,14 @@
   return resultDict ? resultDict.getValue() : llvm::None;
 }
 
+/// Erase the given arguments; updates the type, arg attrs and entry block.
+void eraseFunctionArguments(Operation *op, ArrayRef argIndices,
+                            unsigned originalNumArgs, Type newType);
+
+/// Erase the given results; updates the type attr and result attrs.
+void eraseFunctionResults(Operation *op, ArrayRef resultIndices,
+                          unsigned originalNumResults, Type newType);
+
 } // namespace impl
 
 namespace OpTrait {
@@ -84,12 +92,21 @@
 ///   arguments;
 /// - they can have argument attributes that are stored in a dictionary
 ///   attribute on the Op itself.
-/// This trait does *NOT* provide type support for the functions, meaning that
-/// concrete Ops must handle the type of the declared or defined function.
-/// `getTypeAttrName()` is a convenience function that returns the name of the
-/// attribute that can be used to store the function type, but the trait makes
-/// no assumption based on it.
 ///
+/// This trait provides limited type support for the declared or defined
+/// functions. The convenience function `getTypeAttrName()` returns the name of
+/// an attribute that can be used to store the function type. In addition, this
+/// trait provides `getType` and `setType` helpers to read and write a
+/// `FunctionType` stored in the attribute named by `getTypeAttrName()`.
+///
+/// In general, this trait assumes concrete ops use `FunctionType` under the
+/// hood. If this is not the case, in order to use the function type support,
+/// concrete ops must define the following methods, using the same name, to hide
+/// the ones defined for `FunctionType`: `addBodyBlock`, `getType`,
+/// `getTypeWithoutArgsAndResults` and `setType`.
+///
+/// Besides the requirements above, concrete ops must interact with this trait
+/// using the following functions:
 /// - Concrete ops *must* define a member function `getNumFuncArguments()` that
 ///   returns the number of function arguments based exclusively on type (so
 ///   that it can be called on function declarations).
@@ -183,6 +200,19 @@
     return getTypeAttr().getValue().template cast();
   }
 
+  /// Return the type of this function without the specified arguments and
+  /// results. This is used to update the function's signature in the
+  /// `eraseArguments` and `eraseResults` methods. The arrays of indices are
+  /// allowed to have duplicates and can be in any order.
+  ///
+  /// Note that the concrete class must define a method with the same name to
+  /// hide this one if the concrete class does not use FunctionType for the
+  /// function type under the hood.
+  FunctionType getTypeWithoutArgsAndResults(ArrayRef argIndices,
+                                            ArrayRef resultIndices) {
+    return getType().getWithoutArgsAndResults(argIndices, resultIndices);
+  }
+
   bool isTypeAttrValid() {
     auto typeAttr = getTypeAttr();
     if (!typeAttr)
@@ -204,7 +234,7 @@
   void setType(FunctionType newType);
 
   //===--------------------------------------------------------------------===//
-  // Argument Handling
+  // Argument and Result Handling
   //===--------------------------------------------------------------------===//
 
   using BlockArgListType = Region::BlockArgListType;
@@ -229,6 +259,30 @@
     return getBody().getArgumentTypes();
   }
 
+  /// Erase a single argument at `argIndex`.
+  void eraseArgument(unsigned argIndex) { eraseArguments({argIndex}); }
+
+  /// Erases the arguments listed in `argIndices` and updates the type.
+  /// `argIndices` is allowed to have duplicates and can be in any order.
+  void eraseArguments(ArrayRef argIndices) {
+    unsigned originalNumArgs = getNumArguments();
+    Type newType = getTypeWithoutArgsAndResults(argIndices, {});
+    ::mlir::impl::eraseFunctionArguments(this->getOperation(), argIndices,
+                                         originalNumArgs, newType);
+  }
+
+  /// Erase a single result at `resultIndex`.
+  void eraseResult(unsigned resultIndex) { eraseResults({resultIndex}); }
+
+  /// Erases the results listed in `resultIndices` and updates the type.
+  /// `resultIndices` is allowed to have duplicates and can be in any order.
+  void eraseResults(ArrayRef resultIndices) {
+    unsigned originalNumResults = getNumResults();
+    Type newType = getTypeWithoutArgsAndResults({}, resultIndices);
+    ::mlir::impl::eraseFunctionResults(this->getOperation(), resultIndices,
+                                       originalNumResults, newType);
+  }
+
   //===--------------------------------------------------------------------===//
   // Argument Attributes
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -238,15 +238,19 @@
   static FunctionType get(TypeRange inputs, TypeRange results,
                           MLIRContext *context);
 
-  // Input types.
+  /// Input types.
   unsigned getNumInputs() const;
   Type getInput(unsigned i) const { return getInputs()[i]; }
   ArrayRef getInputs() const;
 
-  // Result types.
+  /// Result types.
   unsigned getNumResults() const;
   Type getResult(unsigned i) const { return getResults()[i]; }
   ArrayRef getResults() const;
+
+  /// Returns a new function type without the specified arguments and results.
+  FunctionType getWithoutArgsAndResults(ArrayRef argIndices,
+                                        ArrayRef resultIndices);
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -25,8 +25,6 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
-#include
-
 #define DEBUG_TYPE "linalg-drop-unit-dims"
 
 using namespace mlir;
@@ -166,9 +164,8 @@
   for (unsigned unitDimLoop : unitDims) {
     entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
   }
-  std::set orderedUnitDims(unitDims.begin(), unitDims.end());
-  for (unsigned i : llvm::reverse(orderedUnitDims))
-    entryBlock->eraseArgument(i);
+  SmallVector unitDimsToErase(unitDims.begin(), unitDims.end());
+  entryBlock->eraseArguments(unitDimsToErase);
   return success();
 }
 
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/Block.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/BitVector.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -176,6 +177,22 @@
   arguments.erase(arguments.begin() + index);
 }
 
+void Block::eraseArguments(ArrayRef argIndices) {
+  llvm::BitVector eraseIndices(getNumArguments());
+  for (unsigned i : argIndices)
+    eraseIndices.set(i);
+  eraseArguments(eraseIndices);
+}
+
+void Block::eraseArguments(llvm::BitVector eraseIndices) {
+  // Walk the indices from last to first: erasing an argument shifts all later
+  // arguments down by one, so erasing back to front keeps indices valid.
+  unsigned originalNumArgs = getNumArguments();
+  for (unsigned i = 0; i < originalNumArgs; ++i)
+    if (eraseIndices.test(originalNumArgs - i - 1))
+      eraseArgument(originalNumArgs - i - 1);
+}
+
 /// Insert one value to the given position of the argument list. The existing
 /// arguments are shifted. The block is expected not to have predecessors.
 BlockArgument Block::insertArgument(args_iterator it, Type type) {
diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt
--- a/mlir/lib/IR/CMakeLists.txt
+++ b/mlir/lib/IR/CMakeLists.txt
@@ -10,6 +10,7 @@
   Dominance.cpp
   Function.cpp
   FunctionImplementation.cpp
+  FunctionSupport.cpp
   IntegerSet.cpp
   Location.cpp
   MLIRContext.cpp
diff --git a/mlir/lib/IR/Function.cpp b/mlir/lib/IR/Function.cpp
--- a/mlir/lib/IR/Function.cpp
+++ b/mlir/lib/IR/Function.cpp
@@ -98,65 +98,6 @@
   return success();
 }
 
-void FuncOp::eraseArguments(ArrayRef argIndices) {
-  auto oldType = getType();
-  int originalNumArgs = oldType.getNumInputs();
-  llvm::BitVector eraseIndices(originalNumArgs);
-  for (auto index : argIndices)
-    eraseIndices.set(index);
-  auto shouldEraseArg = [&](int i) { return eraseIndices.test(i); };
-
-  // There are 3 things that need to be updated:
-  // - Function type.
-  // - Arg attrs.
-  // - Block arguments of entry block.
-
-  // Update the function type and arg attrs.
-  SmallVector newInputTypes;
-  SmallVector newArgAttrs;
-  for (int i = 0; i < originalNumArgs; i++) {
-    if (shouldEraseArg(i))
-      continue;
-    newInputTypes.emplace_back(oldType.getInput(i));
-    newArgAttrs.emplace_back(getArgAttrDict(i));
-  }
-  setType(FunctionType::get(newInputTypes, oldType.getResults(), getContext()));
-  setAllArgAttrs(newArgAttrs);
-
-  // Update the entry block's arguments.
-  // We do this in reverse so that we erase later indices before earlier
-  // indices, to avoid shifting the later indices.
-  Block &entry = front();
-  for (int i = 0; i < originalNumArgs; i++)
-    if (shouldEraseArg(originalNumArgs - i - 1))
-      entry.eraseArgument(originalNumArgs - i - 1);
-}
-
-void FuncOp::eraseResults(ArrayRef resultIndices) {
-  auto oldType = getType();
-  int originalNumResults = oldType.getNumResults();
-  llvm::BitVector eraseIndices(originalNumResults);
-  for (auto index : resultIndices)
-    eraseIndices.set(index);
-  auto shouldEraseResult = [&](int i) { return eraseIndices.test(i); };
-
-  // There are 2 things that need to be updated:
-  // - Function type.
-  // - Result attrs.
-
-  // Update the function type and result attrs.
-  SmallVector newResultTypes;
-  SmallVector newResultAttrs;
-  for (int i = 0; i < originalNumResults; i++) {
-    if (shouldEraseResult(i))
-      continue;
-    newResultTypes.emplace_back(oldType.getResult(i));
-    newResultAttrs.emplace_back(getResultAttrDict(i));
-  }
-  setType(FunctionType::get(oldType.getInputs(), newResultTypes, getContext()));
-  setAllResultAttrs(newResultAttrs);
-}
-
 /// Clone the internal blocks from this function into dest and all attributes
 /// from this function to dest.
 void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) {
diff --git a/mlir/lib/IR/FunctionSupport.cpp b/mlir/lib/IR/FunctionSupport.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/IR/FunctionSupport.cpp
@@ -0,0 +1,103 @@
+//===- FunctionSupport.cpp - Utility types for function-like ops ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/BitVector.h"
+
+using namespace mlir;
+
+/// File-local helper that calls `callback` once for each index in the range
+/// [0, `totalIndices`), *except* for the indices given in `indices`.
+/// `indices` is allowed to have duplicates and can be in any order.
+static void iterateIndicesExcept(unsigned totalIndices,
+                                 ArrayRef<unsigned> indices,
+                                 function_ref<void(unsigned)> callback) {
+  llvm::BitVector skipIndices(totalIndices);
+  for (unsigned i : indices)
+    skipIndices.set(i);
+
+  for (unsigned i = 0; i < totalIndices; ++i)
+    if (!skipIndices.test(i))
+      callback(i);
+}
+
+//===----------------------------------------------------------------------===//
+// Function Arguments and Results.
+//===----------------------------------------------------------------------===//
+
+void mlir::impl::eraseFunctionArguments(Operation *op,
+                                        ArrayRef<unsigned> argIndices,
+                                        unsigned originalNumArgs,
+                                        Type newType) {
+  // There are 3 things that need to be updated:
+  // - Function type.
+  // - Arg attrs.
+  // - Block arguments of entry block, if the function has a body.
+  SmallString<8> nameBuf;
+
+  // Collect arg attrs to set.
+  SmallVector<MutableDictionaryAttr, 4> newArgAttrs;
+  iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) {
+    newArgAttrs.emplace_back(getArgAttrDict(op, i));
+  });
+
+  // Remove the trailing arg attrs whose indices no longer exist.
+  for (unsigned i = newArgAttrs.size(), e = originalNumArgs; i < e; ++i)
+    op->removeAttr(getArgAttrName(i, nameBuf));
+
+  // Set the function type.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Set the new arg attrs, or remove them if empty.
+  for (unsigned i = 0, e = newArgAttrs.size(); i != e; ++i) {
+    auto nameAttr = getArgAttrName(i, nameBuf);
+    auto argAttr = newArgAttrs[i];
+    if (argAttr.empty())
+      op->removeAttr(nameAttr);
+    else
+      op->setAttr(nameAttr, argAttr.getDictionary(op->getContext()));
+  }
+
+  // Update the entry block's arguments; declarations have no entry block.
+  if (!op->getRegion(0).empty())
+    op->getRegion(0).front().eraseArguments(argIndices);
+}
+
+void mlir::impl::eraseFunctionResults(Operation *op,
+                                      ArrayRef<unsigned> resultIndices,
+                                      unsigned originalNumResults,
+                                      Type newType) {
+  // There are 2 things that need to be updated:
+  // - Function type.
+  // - Result attrs.
+  SmallString<8> nameBuf;
+
+  // Collect result attrs to set.
+  SmallVector<MutableDictionaryAttr, 4> newResultAttrs;
+  iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) {
+    newResultAttrs.emplace_back(getResultAttrDict(op, i));
+  });
+
+  // Remove the trailing result attrs whose indices no longer exist.
+  for (unsigned i = newResultAttrs.size(), e = originalNumResults; i < e; ++i)
+    op->removeAttr(getResultAttrName(i, nameBuf));
+
+  // Set the function type.
+  op->setAttr(getTypeAttrName(), TypeAttr::get(newType));
+
+  // Set the new result attrs, or remove them if empty.
+  for (unsigned i = 0, e = newResultAttrs.size(); i != e; ++i) {
+    auto nameAttr = getResultAttrName(i, nameBuf);
+    auto resultAttr = newResultAttrs[i];
+    if (resultAttr.empty())
+      op->removeAttr(nameAttr);
+    else
+      op->setAttr(nameAttr, resultAttr.getDictionary(op->getContext()));
+  }
+}
diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp
--- a/mlir/lib/IR/Types.cpp
+++ b/mlir/lib/IR/Types.cpp
@@ -10,6 +10,8 @@
 #include "TypeDetail.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/Twine.h"
 using namespace mlir;
 
@@ -46,6 +48,48 @@
   return getImpl()->getResults();
 }
 
+/// Helper to call a callback once on each index in the range
+/// `indices` is allowed to have duplicates and can be in any order. +inline void iterateIndicesExcept(unsigned totalIndices, + ArrayRef indices, + function_ref callback) { + llvm::BitVector skipIndices(totalIndices); + for (unsigned i : indices) + skipIndices.set(i); + + for (unsigned i = 0; i < totalIndices; ++i) + if (!skipIndices.test(i)) + callback(i); +} + +/// Returns a new function type without the specified arguments and results. +FunctionType +FunctionType::getWithoutArgsAndResults(ArrayRef argIndices, + ArrayRef resultIndices) { + ArrayRef newInputTypes = getInputs(); + SmallVector newInputTypesBuffer; + if (!argIndices.empty()) { + unsigned originalNumArgs = getNumInputs(); + iterateIndicesExcept(originalNumArgs, argIndices, [&](unsigned i) { + newInputTypesBuffer.emplace_back(getInput(i)); + }); + newInputTypes = newInputTypesBuffer; + } + + ArrayRef newResultTypes = getResults(); + SmallVector newResultTypesBuffer; + if (!resultIndices.empty()) { + unsigned originalNumResults = getNumResults(); + iterateIndicesExcept(originalNumResults, resultIndices, [&](unsigned i) { + newResultTypesBuffer.emplace_back(getResult(i)); + }); + newResultTypes = newResultTypesBuffer; + } + + return get(newInputTypes, newResultTypes, getContext()); +} + //===----------------------------------------------------------------------===// // OpaqueType //===----------------------------------------------------------------------===//